{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "vY4SK0xKAJgm"
   },
   "source": [
    "# Bidirectional Multi-layer RNN with LSTM with Own Dataset in CSV Format (AG News)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Dataset Description\n",
    "\n",
    "```\n",
    "AG's News Topic Classification Dataset\n",
    "\n",
    "Version 3, Updated 09/09/2015\n",
    "\n",
    "\n",
    "ORIGIN\n",
    "\n",
    "AG is a collection of more than 1 million news articles. News articles have been gathered from more than 2000  news sources by ComeToMyHead in more than 1 year of activity. ComeToMyHead is an academic news search engine which has been running since July, 2004. The dataset is provided by the academic community for research purposes in data mining (clustering, classification, etc), information retrieval (ranking, search, etc), xml, data compression, data streaming, and any other non-commercial activity. For more information, please refer to the link http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html .\n",
    "\n",
    "The AG's news topic classification dataset is constructed by Xiang Zhang (xiang.zhang@nyu.edu) from the dataset above. It is used as a text classification benchmark in the following paper: Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances in Neural Information Processing Systems 28 (NIPS 2015).\n",
    "\n",
    "\n",
    "DESCRIPTION\n",
    "\n",
    "The AG's news topic classification dataset is constructed by choosing 4 largest classes from the original corpus. Each class contains 30,000 training samples and 1,900 testing samples. The total number of training samples is 120,000 and testing 7,600.\n",
    "\n",
    "The file classes.txt contains a list of classes corresponding to each label.\n",
    "\n",
    "The files train.csv and test.csv contain all the training samples as comma-separated values. There are 3 columns in them, corresponding to class index (1 to 4), title and description. The title and description are escaped using double quotes (\"), and any internal double quote is escaped by 2 double quotes (\"\"). New lines are escaped by a backslash followed with an \"n\" character, that is \"\\n\".\n",
    "```"
   ]
  },
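  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The quoting rules above are standard CSV conventions, so a CSV reader handles them; only the escaped line breaks need an extra step. A minimal sketch using Python's built-in `csv` module on two made-up rows (the rows are hypothetical and not taken from the dataset):\n",
    "\n",
    "```python\n",
    "import csv\n",
    "\n",
    "# Two hypothetical rows in the AG News layout: class index (1 to 4), title,\n",
    "# description. Internal double quotes are doubled, and a line break inside a\n",
    "# field is stored as the two characters backslash + 'n'.\n",
    "lines = [\n",
    "    '\"3\",\"Space race\",\"A team said \"\"go\"\" today\"',\n",
    "    '\"2\",\"Markets\",\"Stocks rose.\\\\nBonds fell.\"',\n",
    "]\n",
    "\n",
    "rows = []\n",
    "for label, title, description in csv.reader(lines):\n",
    "    # csv.reader removes the quoting; the escaped line break is restored by hand\n",
    "    description = description.replace('\\\\n', '\\n')\n",
    "    rows.append((int(label) - 1, title, description))  # 0-based class index\n",
    "\n",
    "for row in rows:\n",
    "    print(row)\n",
    "```\n",
    "\n",
    "The notebook below uses `pd.read_csv`, which follows the same quoting rules (double quote as quote character, doubled quotes inside fields) by default."
   ]
  },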
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "moNmVfuvnImW"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.7.3\n",
      "IPython 7.9.0\n",
      "\n",
      "torch 1.3.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch\n",
    "\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torchtext import data\n",
    "from torchtext import datasets\n",
    "import time\n",
    "import random\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GSRL42Qgy8I8"
   },
   "source": [
    "## General Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "OvW1RgfepCBq"
   },
   "outputs": [],
   "source": [
    "RANDOM_SEED = 123\n",
    "torch.manual_seed(RANDOM_SEED)\n",
    "\n",
    "VOCABULARY_SIZE = 5000\n",
    "LEARNING_RATE = 1e-3\n",
    "BATCH_SIZE = 128\n",
    "NUM_EPOCHS = 50\n",
    "DROPOUT = 0.5\n",
    "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "EMBEDDING_DIM = 128\n",
    "BIDIRECTIONAL = True\n",
    "HIDDEN_DIM = 256\n",
    "NUM_LAYERS = 2\n",
    "OUTPUT_DIM = 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "mQMmKUEisW4W"
   },
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The AG News dataset is available from Xiang Zhang's Google Drive folder at\n",
    "\n",
    "https://drive.google.com/drive/u/0/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M\n",
    "\n",
    "From the Google Drive folder, download the file \n",
    "\n",
    "- `ag_news_csv.tar.gz`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ag_news_csv/\n",
      "ag_news_csv/train.csv\n",
      "ag_news_csv/test.csv\n",
      "ag_news_csv/classes.txt\n",
      "ag_news_csv/readme.txt\n"
     ]
    }
   ],
   "source": [
    "!tar xvzf ag_news_csv.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "World\n",
      "Sports\n",
      "Business\n",
      "Sci/Tech\n"
     ]
    }
   ],
   "source": [
    "!cat ag_news_csv/classes.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check that the dataset looks okay:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>classlabel</th>\n",
       "      <th>title</th>\n",
       "      <th>content</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2</td>\n",
       "      <td>Wall St. Bears Claw Back Into the Black (Reuters)</td>\n",
       "      <td>Reuters - Short-sellers, Wall Street's dwindli...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>Carlyle Looks Toward Commercial Aerospace (Reu...</td>\n",
       "      <td>Reuters - Private investment firm Carlyle Grou...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>Oil and Economy Cloud Stocks' Outlook (Reuters)</td>\n",
       "      <td>Reuters - Soaring crude prices plus worries\\ab...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>Iraq Halts Oil Exports from Main Southern Pipe...</td>\n",
       "      <td>Reuters - Authorities have halted oil export\\f...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2</td>\n",
       "      <td>Oil prices soar to all-time record, posing new...</td>\n",
       "      <td>AFP - Tearaway world oil prices, toppling reco...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   classlabel                                              title  \\\n",
       "0           2  Wall St. Bears Claw Back Into the Black (Reuters)   \n",
       "1           2  Carlyle Looks Toward Commercial Aerospace (Reu...   \n",
       "2           2    Oil and Economy Cloud Stocks' Outlook (Reuters)   \n",
       "3           2  Iraq Halts Oil Exports from Main Southern Pipe...   \n",
       "4           2  Oil prices soar to all-time record, posing new...   \n",
       "\n",
       "                                             content  \n",
       "0  Reuters - Short-sellers, Wall Street's dwindli...  \n",
       "1  Reuters - Private investment firm Carlyle Grou...  \n",
       "2  Reuters - Soaring crude prices plus worries\\ab...  \n",
       "3  Reuters - Authorities have halted oil export\\f...  \n",
       "4  AFP - Tearaway world oil prices, toppling reco...  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('ag_news_csv/train.csv', header=None, index_col=None)\n",
    "df.columns = ['classlabel', 'title', 'content']\n",
    "df['classlabel'] = df['classlabel']-1  # shift class indices from 1-4 to 0-3\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 2, 3])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.unique(df['classlabel'].values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([30000, 30000, 30000, 30000])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.bincount(df['classlabel'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "df[['classlabel', 'content']].to_csv('ag_news_csv/train_prepocessed.csv', index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>classlabel</th>\n",
       "      <th>title</th>\n",
       "      <th>content</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2</td>\n",
       "      <td>Fears for T N pension after talks</td>\n",
       "      <td>Unions representing workers at Turner   Newall...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3</td>\n",
       "      <td>The Race is On: Second Private Team Sets Launc...</td>\n",
       "      <td>SPACE.com - TORONTO, Canada -- A second\\team o...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>Ky. Company Wins Grant to Study Peptides (AP)</td>\n",
       "      <td>AP - A company founded by a chemistry research...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>Prediction Unit Helps Forecast Wildfires (AP)</td>\n",
       "      <td>AP - It's barely dawn when Mike Fitzpatrick st...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3</td>\n",
       "      <td>Calif. Aims to Limit Farm-Related Smog (AP)</td>\n",
       "      <td>AP - Southern California's smog-fighting agenc...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   classlabel                                              title  \\\n",
       "0           2                  Fears for T N pension after talks   \n",
       "1           3  The Race is On: Second Private Team Sets Launc...   \n",
       "2           3      Ky. Company Wins Grant to Study Peptides (AP)   \n",
       "3           3      Prediction Unit Helps Forecast Wildfires (AP)   \n",
       "4           3        Calif. Aims to Limit Farm-Related Smog (AP)   \n",
       "\n",
       "                                             content  \n",
       "0  Unions representing workers at Turner   Newall...  \n",
       "1  SPACE.com - TORONTO, Canada -- A second\\team o...  \n",
       "2  AP - A company founded by a chemistry research...  \n",
       "3  AP - It's barely dawn when Mike Fitzpatrick st...  \n",
       "4  AP - Southern California's smog-fighting agenc...  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('ag_news_csv/test.csv', header=None, index_col=None)\n",
    "df.columns = ['classlabel', 'title', 'content']\n",
    "df['classlabel'] = df['classlabel']-1  # shift class indices from 1-4 to 0-3\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 2, 3])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.unique(df['classlabel'].values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1900, 1900, 1900, 1900])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.bincount(df['classlabel'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "df[['classlabel', 'content']].to_csv('ag_news_csv/test_prepocessed.csv', index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "del df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "4GnH64XvsV8n"
   },
   "source": [
    "Define the Label and Text field formatters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "TEXT = data.Field(sequential=True,\n",
    "                  tokenize='spacy',\n",
    "                  include_lengths=True) # necessary for pack_padded_sequence\n",
    "\n",
    "LABEL = data.LabelField(dtype=torch.long) # integer class indices for cross-entropy\n",
    "\n",
    "\n",
    "# If you get an error [E050] Can't find model 'en'\n",
    "# you need to run the following on your command line:\n",
    "#  python -m spacy download en"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Process the dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "fields = [('classlabel', LABEL), ('content', TEXT)]\n",
    "\n",
    "train_dataset = data.TabularDataset(\n",
    "    path=\"ag_news_csv/train_prepocessed.csv\", format='csv',\n",
    "    skip_header=True, fields=fields)\n",
    "\n",
    "test_dataset = data.TabularDataset(\n",
    "    path=\"ag_news_csv/test_prepocessed.csv\", format='csv',\n",
    "    skip_header=True, fields=fields)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Split the training dataset into training (95%) and validation (5%) subsets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 68
    },
    "colab_type": "code",
    "id": "WZ_4jiHVnMxN",
    "outputId": "dfa51c04-4845-44c3-f50b-d36d41f132b8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num Train: 114000\n",
      "Num Valid: 6000\n"
     ]
    }
   ],
   "source": [
    "random.seed(RANDOM_SEED)\n",
    "train_data, valid_data = train_dataset.split(\n",
    "    split_ratio=[0.95, 0.05],\n",
    "    random_state=random.getstate())  # fixed state for a reproducible split\n",
    "\n",
    "print(f'Num Train: {len(train_data)}')\n",
    "print(f'Num Valid: {len(valid_data)}')"
   ]
  },
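  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reproducibility here only requires that the shuffle before the split starts from a fixed random state. As a plain-Python sketch of the same idea (the helper `seeded_split` is hypothetical and not part of torchtext; it uses a private `random.Random` instance instead of the global RNG), a seeded shuffle followed by a cut at 95% yields the 114,000/6,000 sizes printed above:\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def seeded_split(items, train_ratio=0.95, seed=123):\n",
    "    # Shuffle a copy with a private, seeded RNG so the global random state\n",
    "    # is untouched, then cut the list at the rounded ratio\n",
    "    rng = random.Random(seed)\n",
    "    shuffled = list(items)\n",
    "    rng.shuffle(shuffled)\n",
    "    n_train = int(round(train_ratio * len(shuffled)))\n",
    "    return shuffled[:n_train], shuffled[n_train:]\n",
    "\n",
    "\n",
    "train_idx, valid_idx = seeded_split(range(120000))\n",
    "print(len(train_idx), len(valid_idx))  # 114000 6000\n",
    "\n",
    "# The same seed always produces the same partition\n",
    "assert seeded_split(range(10), seed=1) == seeded_split(range(10), seed=1)\n",
    "```"
   ]
  },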
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "L-TBwKWPslPa"
   },
   "source": [
    "Build the vocabulary from the `VOCABULARY_SIZE` most frequent words in the training data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "e8uNrjdtn4A8",
    "outputId": "6cf499d7-7722-4da0-8576-ee0f218cc6e3"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Vocabulary size: 5002\n",
      "Number of classes: 4\n"
     ]
    }
   ],
   "source": [
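    "# Note: the GloVe vectors below (100-dimensional) are stored in TEXT.vocab.vectors\n",
    "# but are not copied into the model's embedding layer (EMBEDDING_DIM = 128)\n",
    "TEXT.build_vocab(train_data,\n",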
    "                 max_size=VOCABULARY_SIZE,\n",
    "                 vectors='glove.6B.100d',\n",
    "                 unk_init=torch.Tensor.normal_)\n",
    "\n",
    "LABEL.build_vocab(train_data)\n",
    "\n",
    "print(f'Vocabulary size: {len(TEXT.vocab)}')\n",
    "print(f'Number of classes: {len(LABEL.vocab)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['1', '3', '0', '2']"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list(LABEL.vocab.freqs)[-10:]"
   ]
  },
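  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`list(LABEL.vocab.freqs)` shows the label strings, but not the indices the model is trained on. The legacy torchtext `Vocab` (assumed here for the version used in this notebook) assigns indices by descending frequency with ties broken alphabetically, so a predicted class index need not equal the original label; `LABEL.vocab.itos[i]` gives the label string for index `i`. A plain-Python sketch of that ordering, using made-up counts (not the actual counts after the split):\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "# Made-up label counts, for illustration only\n",
    "freqs = Counter({'0': 28500, '1': 28600, '2': 28400, '3': 28500})\n",
    "\n",
    "# Sort by frequency (descending), breaking ties alphabetically\n",
    "itos = [label for label, _ in sorted(freqs.items(), key=lambda kv: (-kv[1], kv[0]))]\n",
    "stoi = {label: idx for idx, label in enumerate(itos)}\n",
    "print(itos)  # ['1', '0', '3', '2']\n",
    "\n",
    "# Map a predicted class index back to the class name from classes.txt\n",
    "class_names = ['World', 'Sports', 'Business', 'Sci/Tech']\n",
    "predicted_index = 0\n",
    "print(class_names[int(itos[predicted_index])])  # Sports\n",
    "```"
   ]
  },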
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "JpEMNInXtZsb"
   },
   "source": [
    "The `TEXT.vocab` object stores the word counts and the word-to-index mapping. The vocabulary has VOCABULARY_SIZE + 2 entries because it also contains two special tokens: `<unk>` for unknown words and `<pad>` for padding."
   ]
  },
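  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The bookkeeping behind the VOCABULARY_SIZE + 2 count can be sketched in plain Python. The helper `build_vocab` below is hypothetical; it assumes, like torchtext's default, that `<unk>` and `<pad>` occupy indices 0 and 1 and that out-of-vocabulary tokens map to `<unk>`:\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def build_vocab(token_lists, max_size, specials=('<unk>', '<pad>')):\n",
    "    # Keep the max_size most frequent tokens (ties broken alphabetically)\n",
    "    # and put the special tokens first, so len(itos) == max_size + len(specials)\n",
    "    counter = Counter(tok for tokens in token_lists for tok in tokens)\n",
    "    for special in specials:\n",
    "        counter.pop(special, None)  # specials are not counted as regular words\n",
    "    ranked = sorted(counter.items(), key=lambda kv: (-kv[1], kv[0]))[:max_size]\n",
    "    itos = list(specials) + [tok for tok, _ in ranked]\n",
    "    stoi = {tok: idx for idx, tok in enumerate(itos)}\n",
    "    return itos, stoi\n",
    "\n",
    "\n",
    "docs = [['oil', 'prices', 'soar'], ['oil', 'stocks', 'fall'], ['stocks', 'rise']]\n",
    "itos, stoi = build_vocab(docs, max_size=3)\n",
    "print(itos)  # ['<unk>', '<pad>', 'oil', 'stocks', 'fall']\n",
    "\n",
    "# Tokens outside the vocabulary fall back to the <unk> index\n",
    "print([stoi.get(tok, stoi['<unk>']) for tok in ['oil', 'soar', 'rally']])  # [2, 0, 0]\n",
    "```"
   ]
  },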
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "eIQ_zfKLwjKm"
   },
   "source": [
    "Make dataset iterators:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "i7JiHR1stHNF"
   },
   "outputs": [],
   "source": [
    "train_loader, valid_loader, test_loader = data.BucketIterator.splits(\n",
    "    (train_data, valid_data, test_dataset), \n",
    "    batch_size=BATCH_SIZE,\n",
    "    sort_within_batch=True, # necessary for pack_padded_sequence\n",
    "    sort_key=lambda x: len(x.content),\n",
    "    device=DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "R0pT_dMRvicQ"
   },
   "source": [
    "Test the iterators (the number of rows, i.e., the padded sequence length, depends on the longest document in each batch):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 204
    },
    "colab_type": "code",
    "id": "y8SP_FccutT0",
    "outputId": "fe33763a-4560-4dee-adee-31cc6c48b0b2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train\n",
      "Text matrix size: torch.Size([35, 128])\n",
      "Target vector size: torch.Size([128])\n",
      "\n",
      "Valid:\n",
      "Text matrix size: torch.Size([17, 128])\n",
      "Target vector size: torch.Size([128])\n",
      "\n",
      "Test:\n",
      "Text matrix size: torch.Size([16, 128])\n",
      "Target vector size: torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "print('Train')\n",
    "for batch in train_loader:\n",
    "    print(f'Text matrix size: {batch.content[0].size()}')\n",
    "    print(f'Target vector size: {batch.classlabel.size()}')\n",
    "    break\n",
    "    \n",
    "print('\\nValid:')\n",
    "for batch in valid_loader:\n",
    "    print(f'Text matrix size: {batch.content[0].size()}')\n",
    "    print(f'Target vector size: {batch.classlabel.size()}')\n",
    "    break\n",
    "    \n",
    "print('\\nTest:')\n",
    "for batch in test_loader:\n",
    "    print(f'Text matrix size: {batch.content[0].size()}')\n",
    "    print(f'Target vector size: {batch.classlabel.size()}')\n",
    "    break"
   ]
  },
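  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shapes above are time-major (`[sequence length, batch size]`, torchtext's default with `batch_first=False`): each batch is padded to its longest document, and `sort_within_batch=True` orders the documents by decreasing length, which `pack_padded_sequence` expects by default. A plain-Python sketch of that step, with a hypothetical helper `pad_batch` and the padding index assumed to be 1 (the default position of `<pad>`):\n",
    "\n",
    "```python\n",
    "def pad_batch(token_id_lists, pad_idx=1):\n",
    "    # Sort by length (longest first) and pad every sequence to the longest one\n",
    "    batch = sorted(token_id_lists, key=len, reverse=True)\n",
    "    lengths = [len(seq) for seq in batch]\n",
    "    max_len = lengths[0]\n",
    "    padded = [seq + [pad_idx] * (max_len - len(seq)) for seq in batch]\n",
    "    # Transpose to a time-major matrix: one row per time step, one column per document\n",
    "    matrix = [list(step) for step in zip(*padded)]\n",
    "    return matrix, lengths\n",
    "\n",
    "\n",
    "matrix, lengths = pad_batch([[5, 6], [7, 8, 9, 10], [11]])\n",
    "print(len(matrix), len(matrix[0]))  # 4 3\n",
    "print(lengths)  # [4, 2, 1]\n",
    "print(matrix[-1])  # [10, 1, 1]\n",
    "```"
   ]
  },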
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "G_grdW3pxCzz"
   },
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "nQIUm5EjxFNa"
   },
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class RNN(nn.Module):\n",
    "    def __init__(self, input_dim, embedding_dim, bidirectional, hidden_dim, num_layers, output_dim, dropout, pad_idx):\n",
    "        \n",
    "        super().__init__()\n",
    "        \n",
    "        self.embedding = nn.Embedding(input_dim, embedding_dim, padding_idx=pad_idx)\n",
    "        self.rnn = nn.LSTM(embedding_dim, \n",
    "                           hidden_dim,\n",
    "                           num_layers=num_layers,\n",
    "                           bidirectional=bidirectional, \n",
    "                           dropout=dropout)\n",
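    "        # With bidirectional=True, hidden[-2] and hidden[-1] are the final forward and\n",
    "        # backward states of the top LSTM layer; concatenated, they have 2 * hidden_dim features\n",
    "        self.fc1 = nn.Linear(hidden_dim * 2, 64)\n",
    "        self.fc2 = nn.Linear(64, output_dim)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        \n",
    "    def forward(self, text, text_length):\n",
    "        # text: [sentence length, batch size]\n",
    "        embedded = self.dropout(self.embedding(text))\n",
    "        # newer PyTorch versions require the lengths to be a CPU tensor\n",
    "        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_length.cpu())\n",
    "        packed_output, (hidden, cell) = self.rnn(packed_embedded)\n",
    "        # hidden: [num_layers * num_directions, batch size, hidden_dim]\n",
    "        hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))\n",
    "        hidden = torch.relu(self.fc1(hidden))\n",
    "        # logits: [batch size, output_dim]\n",
    "        return self.fc2(hidden)"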
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Ik3NF3faxFmZ"
   },
   "outputs": [],
   "source": [
    "INPUT_DIM = len(TEXT.vocab)\n",
    "\n",
    "PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n",
    "\n",
    "torch.manual_seed(RANDOM_SEED)\n",
    "model = RNN(INPUT_DIM, EMBEDDING_DIM, BIDIRECTIONAL, HIDDEN_DIM, NUM_LAYERS, OUTPUT_DIM, DROPOUT, PAD_IDX)\n",
    "model = model.to(DEVICE)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Lv9Ny9di6VcI"
   },
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "T5t1Afn4xO11"
   },
   "outputs": [],
   "source": [
    "def compute_accuracy(model, data_loader, device):\n",
    "    model.eval()\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, batch_data in enumerate(data_loader):\n",
    "            text, text_lengths = batch_data.content\n",
    "            logits = model(text, text_lengths).squeeze(1)\n",
    "            _, predicted_labels = torch.max(logits, 1)\n",
    "            num_examples += batch_data.classlabel.size(0)\n",
    "            correct_pred += (predicted_labels.long() == batch_data.classlabel.long()).sum()\n",
    "        return correct_pred.float()/num_examples * 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1836
    },
    "colab_type": "code",
    "id": "EABZM8Vo0ilB",
    "outputId": "5d45e293-9909-4588-e793-8dfaf72e5c67"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/050 | Batch 000/891 | Cost: 4.1667\n",
      "Epoch: 001/050 | Batch 050/891 | Cost: 1.4755\n",
      "Epoch: 001/050 | Batch 100/891 | Cost: 1.3285\n",
      "Epoch: 001/050 | Batch 150/891 | Cost: 1.2829\n",
      "Epoch: 001/050 | Batch 200/891 | Cost: 1.1988\n",
      "Epoch: 001/050 | Batch 250/891 | Cost: 0.7590\n",
      "Epoch: 001/050 | Batch 300/891 | Cost: 0.9044\n",
      "Epoch: 001/050 | Batch 350/891 | Cost: 0.8458\n",
      "Epoch: 001/050 | Batch 400/891 | Cost: 0.6982\n",
      "Epoch: 001/050 | Batch 450/891 | Cost: 0.5888\n",
      "Epoch: 001/050 | Batch 500/891 | Cost: 0.5934\n",
      "Epoch: 001/050 | Batch 550/891 | Cost: 0.6286\n",
      "Epoch: 001/050 | Batch 600/891 | Cost: 0.5483\n",
      "Epoch: 001/050 | Batch 650/891 | Cost: 0.6919\n",
      "Epoch: 001/050 | Batch 700/891 | Cost: 0.3992\n",
      "Epoch: 001/050 | Batch 750/891 | Cost: 0.8836\n",
      "Epoch: 001/050 | Batch 800/891 | Cost: 0.5266\n",
      "Epoch: 001/050 | Batch 850/891 | Cost: 0.5314\n",
      "training accuracy: 84.63%\n",
      "valid accuracy: 83.77%\n",
      "Time elapsed: 0.46 min\n",
      "Epoch: 002/050 | Batch 000/891 | Cost: 0.3907\n",
      "Epoch: 002/050 | Batch 050/891 | Cost: 0.3792\n",
      "Epoch: 002/050 | Batch 100/891 | Cost: 0.4874\n",
      "Epoch: 002/050 | Batch 150/891 | Cost: 0.5823\n",
      "Epoch: 002/050 | Batch 200/891 | Cost: 0.4622\n",
      "Epoch: 002/050 | Batch 250/891 | Cost: 0.3818\n",
      "Epoch: 002/050 | Batch 300/891 | Cost: 0.4743\n",
      "Epoch: 002/050 | Batch 350/891 | Cost: 0.5085\n",
      "Epoch: 002/050 | Batch 400/891 | Cost: 0.4229\n",
      "Epoch: 002/050 | Batch 450/891 | Cost: 0.3666\n",
      "Epoch: 002/050 | Batch 500/891 | Cost: 0.3102\n",
      "Epoch: 002/050 | Batch 550/891 | Cost: 0.4300\n",
      "Epoch: 002/050 | Batch 600/891 | Cost: 0.6906\n",
      "Epoch: 002/050 | Batch 650/891 | Cost: 0.3315\n",
      "Epoch: 002/050 | Batch 700/891 | Cost: 0.4410\n",
      "Epoch: 002/050 | Batch 750/891 | Cost: 0.3719\n",
      "Epoch: 002/050 | Batch 800/891 | Cost: 0.4229\n",
      "Epoch: 002/050 | Batch 850/891 | Cost: 0.5765\n",
      "training accuracy: 89.51%\n",
      "valid accuracy: 88.93%\n",
      "Time elapsed: 0.93 min\n",
      "Epoch: 003/050 | Batch 000/891 | Cost: 0.4050\n",
      "Epoch: 003/050 | Batch 050/891 | Cost: 0.3719\n",
      "Epoch: 003/050 | Batch 100/891 | Cost: 0.3914\n",
      "Epoch: 003/050 | Batch 150/891 | Cost: 0.2547\n",
      "Epoch: 003/050 | Batch 200/891 | Cost: 0.2478\n",
      "Epoch: 003/050 | Batch 250/891 | Cost: 0.6579\n",
      "Epoch: 003/050 | Batch 300/891 | Cost: 0.3390\n",
      "Epoch: 003/050 | Batch 350/891 | Cost: 0.4368\n",
      "Epoch: 003/050 | Batch 400/891 | Cost: 0.3960\n",
      "Epoch: 003/050 | Batch 450/891 | Cost: 0.2799\n",
      "Epoch: 003/050 | Batch 500/891 | Cost: 0.2862\n",
      "Epoch: 003/050 | Batch 550/891 | Cost: 0.3342\n",
      "Epoch: 003/050 | Batch 600/891 | Cost: 0.2348\n",
      "Epoch: 003/050 | Batch 650/891 | Cost: 0.3088\n",
      "Epoch: 003/050 | Batch 700/891 | Cost: 0.7425\n",
      "Epoch: 003/050 | Batch 750/891 | Cost: 0.2534\n",
      "Epoch: 003/050 | Batch 800/891 | Cost: 0.3224\n",
      "Epoch: 003/050 | Batch 850/891 | Cost: 0.2275\n",
      "training accuracy: 90.49%\n",
      "valid accuracy: 89.32%\n",
      "Time elapsed: 1.40 min\n",
      "Epoch: 004/050 | Batch 000/891 | Cost: 0.2450\n",
      "Epoch: 004/050 | Batch 050/891 | Cost: 0.2518\n",
      "Epoch: 004/050 | Batch 100/891 | Cost: 0.6905\n",
      "Epoch: 004/050 | Batch 150/891 | Cost: 0.3877\n",
      "Epoch: 004/050 | Batch 200/891 | Cost: 0.2438\n",
      "Epoch: 004/050 | Batch 250/891 | Cost: 0.2047\n",
      "Epoch: 004/050 | Batch 300/891 | Cost: 0.2984\n",
      "Epoch: 004/050 | Batch 350/891 | Cost: 0.4487\n",
      "Epoch: 004/050 | Batch 400/891 | Cost: 0.2900\n",
      "Epoch: 004/050 | Batch 450/891 | Cost: 0.2992\n",
      "Epoch: 004/050 | Batch 500/891 | Cost: 0.2952\n",
      "Epoch: 004/050 | Batch 550/891 | Cost: 0.2289\n",
      "Epoch: 004/050 | Batch 600/891 | Cost: 0.2467\n",
      "Epoch: 004/050 | Batch 650/891 | Cost: 0.1343\n",
      "Epoch: 004/050 | Batch 700/891 | Cost: 0.2538\n",
      "Epoch: 004/050 | Batch 750/891 | Cost: 0.3580\n",
      "Epoch: 004/050 | Batch 800/891 | Cost: 0.3781\n",
      "Epoch: 004/050 | Batch 850/891 | Cost: 0.2254\n",
      "training accuracy: 91.43%\n",
      "valid accuracy: 90.47%\n",
      "Time elapsed: 1.86 min\n",
      "Epoch: 005/050 | Batch 000/891 | Cost: 0.4273\n",
      "Epoch: 005/050 | Batch 050/891 | Cost: 0.3250\n",
      "Epoch: 005/050 | Batch 100/891 | Cost: 0.4769\n",
      "Epoch: 005/050 | Batch 150/891 | Cost: 0.3298\n",
      "Epoch: 005/050 | Batch 200/891 | Cost: 0.3183\n",
      "Epoch: 005/050 | Batch 250/891 | Cost: 0.2533\n",
      "Epoch: 005/050 | Batch 300/891 | Cost: 0.2897\n",
      "Epoch: 005/050 | Batch 350/891 | Cost: 0.2772\n",
      "Epoch: 005/050 | Batch 400/891 | Cost: 0.3040\n",
      "Epoch: 005/050 | Batch 450/891 | Cost: 0.2332\n",
      "Epoch: 005/050 | Batch 500/891 | Cost: 0.2608\n",
      "Epoch: 005/050 | Batch 550/891 | Cost: 0.2563\n",
      "Epoch: 005/050 | Batch 600/891 | Cost: 0.3264\n",
      "Epoch: 005/050 | Batch 650/891 | Cost: 0.2695\n",
      "Epoch: 005/050 | Batch 700/891 | Cost: 0.4137\n",
      "Epoch: 005/050 | Batch 750/891 | Cost: 0.2787\n",
      "Epoch: 005/050 | Batch 800/891 | Cost: 0.3102\n",
      "Epoch: 005/050 | Batch 850/891 | Cost: 0.2707\n",
      "training accuracy: 92.33%\n",
      "valid accuracy: 90.92%\n",
      "Time elapsed: 2.34 min\n",
      "Epoch: 006/050 | Batch 000/891 | Cost: 0.5286\n",
      "Epoch: 006/050 | Batch 050/891 | Cost: 0.1996\n",
      "Epoch: 006/050 | Batch 100/891 | Cost: 0.3859\n",
      "Epoch: 006/050 | Batch 150/891 | Cost: 0.2322\n",
      "Epoch: 006/050 | Batch 200/891 | Cost: 0.2821\n",
      "Epoch: 006/050 | Batch 250/891 | Cost: 0.3530\n",
      "Epoch: 006/050 | Batch 300/891 | Cost: 0.3880\n",
      "Epoch: 006/050 | Batch 350/891 | Cost: 0.4259\n",
      "Epoch: 006/050 | Batch 400/891 | Cost: 0.3522\n",
      "Epoch: 006/050 | Batch 450/891 | Cost: 0.3299\n",
      "Epoch: 006/050 | Batch 500/891 | Cost: 0.3318\n",
      "Epoch: 006/050 | Batch 550/891 | Cost: 0.3139\n",
      "Epoch: 006/050 | Batch 600/891 | Cost: 0.2604\n",
      "Epoch: 006/050 | Batch 650/891 | Cost: 0.2049\n",
      "Epoch: 006/050 | Batch 700/891 | Cost: 0.2948\n",
      "Epoch: 006/050 | Batch 750/891 | Cost: 0.2000\n",
      "Epoch: 006/050 | Batch 800/891 | Cost: 0.1694\n",
      "Epoch: 006/050 | Batch 850/891 | Cost: 0.3553\n",
      "training accuracy: 92.73%\n",
      "valid accuracy: 91.20%\n",
      "Time elapsed: 2.81 min\n",
      "Epoch: 007/050 | Batch 000/891 | Cost: 0.2787\n",
      "Epoch: 007/050 | Batch 050/891 | Cost: 0.2766\n",
      "Epoch: 007/050 | Batch 100/891 | Cost: 0.3586\n",
      "Epoch: 007/050 | Batch 150/891 | Cost: 0.2167\n",
      "Epoch: 007/050 | Batch 200/891 | Cost: 0.2809\n",
      "Epoch: 007/050 | Batch 250/891 | Cost: 0.1589\n",
      "Epoch: 007/050 | Batch 300/891 | Cost: 0.2980\n",
      "Epoch: 007/050 | Batch 350/891 | Cost: 0.2061\n",
      "Epoch: 007/050 | Batch 400/891 | Cost: 0.2757\n",
      "Epoch: 007/050 | Batch 450/891 | Cost: 0.2706\n",
      "Epoch: 007/050 | Batch 500/891 | Cost: 0.1621\n",
      "Epoch: 007/050 | Batch 550/891 | Cost: 0.2763\n",
      "Epoch: 007/050 | Batch 600/891 | Cost: 0.2122\n",
      "Epoch: 007/050 | Batch 650/891 | Cost: 0.3193\n",
      "Epoch: 007/050 | Batch 700/891 | Cost: 0.3161\n",
      "Epoch: 007/050 | Batch 750/891 | Cost: 0.5697\n",
      "Epoch: 007/050 | Batch 800/891 | Cost: 0.2462\n",
      "Epoch: 007/050 | Batch 850/891 | Cost: 0.4072\n",
      "training accuracy: 93.18%\n",
      "valid accuracy: 91.13%\n",
      "Time elapsed: 3.29 min\n",
      "Epoch: 008/050 | Batch 000/891 | Cost: 0.1531\n",
      "Epoch: 008/050 | Batch 050/891 | Cost: 0.2815\n",
      "Epoch: 008/050 | Batch 100/891 | Cost: 0.1890\n",
      "Epoch: 008/050 | Batch 150/891 | Cost: 0.3430\n",
      "Epoch: 008/050 | Batch 200/891 | Cost: 0.3179\n",
      "Epoch: 008/050 | Batch 250/891 | Cost: 0.1990\n",
      "Epoch: 008/050 | Batch 300/891 | Cost: 0.2313\n",
      "Epoch: 008/050 | Batch 350/891 | Cost: 0.1431\n",
      "Epoch: 008/050 | Batch 400/891 | Cost: 0.1857\n",
      "Epoch: 008/050 | Batch 450/891 | Cost: 0.3604\n",
      "Epoch: 008/050 | Batch 500/891 | Cost: 0.3531\n",
      "Epoch: 008/050 | Batch 550/891 | Cost: 0.2136\n",
      "Epoch: 008/050 | Batch 600/891 | Cost: 0.3887\n",
      "Epoch: 008/050 | Batch 650/891 | Cost: 0.2011\n",
      "Epoch: 008/050 | Batch 700/891 | Cost: 0.1803\n",
      "Epoch: 008/050 | Batch 750/891 | Cost: 0.3328\n",
      "Epoch: 008/050 | Batch 800/891 | Cost: 0.2284\n",
      "Epoch: 008/050 | Batch 850/891 | Cost: 0.1928\n",
      "training accuracy: 93.41%\n",
      "valid accuracy: 91.22%\n",
      "Time elapsed: 3.77 min\n",
      "Epoch: 009/050 | Batch 000/891 | Cost: 0.2629\n",
      "Epoch: 009/050 | Batch 050/891 | Cost: 0.2781\n",
      "Epoch: 009/050 | Batch 100/891 | Cost: 0.2318\n",
      "Epoch: 009/050 | Batch 150/891 | Cost: 0.2701\n",
      "Epoch: 009/050 | Batch 200/891 | Cost: 0.1944\n",
      "Epoch: 009/050 | Batch 250/891 | Cost: 0.3229\n",
      "Epoch: 009/050 | Batch 300/891 | Cost: 0.2979\n",
      "Epoch: 009/050 | Batch 350/891 | Cost: 0.2095\n",
      "Epoch: 009/050 | Batch 400/891 | Cost: 0.1358\n",
      "Epoch: 009/050 | Batch 450/891 | Cost: 0.2221\n",
      "Epoch: 009/050 | Batch 500/891 | Cost: 0.1896\n",
      "Epoch: 009/050 | Batch 550/891 | Cost: 0.2059\n",
      "Epoch: 009/050 | Batch 600/891 | Cost: 0.2914\n",
      "Epoch: 009/050 | Batch 650/891 | Cost: 0.4117\n",
      "Epoch: 009/050 | Batch 700/891 | Cost: 0.2545\n",
      "Epoch: 009/050 | Batch 750/891 | Cost: 0.3608\n",
      "Epoch: 009/050 | Batch 800/891 | Cost: 0.2593\n",
      "Epoch: 009/050 | Batch 850/891 | Cost: 0.1308\n",
      "training accuracy: 93.86%\n",
      "valid accuracy: 91.33%\n",
      "Time elapsed: 4.26 min\n",
      "Epoch: 010/050 | Batch 000/891 | Cost: 0.1087\n",
      "Epoch: 010/050 | Batch 050/891 | Cost: 0.1921\n",
      "Epoch: 010/050 | Batch 100/891 | Cost: 0.1257\n",
      "Epoch: 010/050 | Batch 150/891 | Cost: 0.4087\n",
      "Epoch: 010/050 | Batch 200/891 | Cost: 0.2603\n",
      "Epoch: 010/050 | Batch 250/891 | Cost: 0.1607\n",
      "Epoch: 010/050 | Batch 300/891 | Cost: 0.2791\n",
      "Epoch: 010/050 | Batch 350/891 | Cost: 0.1774\n",
      "Epoch: 010/050 | Batch 400/891 | Cost: 0.5015\n",
      "Epoch: 010/050 | Batch 450/891 | Cost: 0.2276\n",
      "Epoch: 010/050 | Batch 500/891 | Cost: 0.2954\n",
      "Epoch: 010/050 | Batch 550/891 | Cost: 0.1906\n",
      "Epoch: 010/050 | Batch 600/891 | Cost: 0.2464\n",
      "Epoch: 010/050 | Batch 650/891 | Cost: 0.2425\n",
      "Epoch: 010/050 | Batch 700/891 | Cost: 0.2000\n",
      "Epoch: 010/050 | Batch 750/891 | Cost: 0.2981\n",
      "Epoch: 010/050 | Batch 800/891 | Cost: 0.2060\n",
      "Epoch: 010/050 | Batch 850/891 | Cost: 0.2032\n",
      "training accuracy: 94.30%\n",
      "valid accuracy: 91.70%\n",
      "Time elapsed: 4.74 min\n",
      "Epoch: 011/050 | Batch 000/891 | Cost: 0.2229\n",
      "Epoch: 011/050 | Batch 050/891 | Cost: 0.2725\n",
      "Epoch: 011/050 | Batch 100/891 | Cost: 0.1801\n",
      "Epoch: 011/050 | Batch 150/891 | Cost: 0.2125\n",
      "Epoch: 011/050 | Batch 200/891 | Cost: 0.1482\n",
      "Epoch: 011/050 | Batch 250/891 | Cost: 0.2237\n",
      "Epoch: 011/050 | Batch 300/891 | Cost: 0.1581\n",
      "Epoch: 011/050 | Batch 350/891 | Cost: 0.3981\n",
      "Epoch: 011/050 | Batch 400/891 | Cost: 0.2683\n",
      "Epoch: 011/050 | Batch 450/891 | Cost: 0.2471\n",
      "Epoch: 011/050 | Batch 500/891 | Cost: 0.1495\n",
      "Epoch: 011/050 | Batch 550/891 | Cost: 0.2281\n",
      "Epoch: 011/050 | Batch 600/891 | Cost: 0.2023\n",
      "Epoch: 011/050 | Batch 650/891 | Cost: 0.1069\n",
      "Epoch: 011/050 | Batch 700/891 | Cost: 0.1906\n",
      "Epoch: 011/050 | Batch 750/891 | Cost: 0.2770\n",
      "Epoch: 011/050 | Batch 800/891 | Cost: 0.1736\n",
      "Epoch: 011/050 | Batch 850/891 | Cost: 0.1480\n",
      "training accuracy: 94.57%\n",
      "valid accuracy: 91.77%\n",
      "Time elapsed: 5.23 min\n",
      "Epoch: 012/050 | Batch 000/891 | Cost: 0.1419\n",
      "Epoch: 012/050 | Batch 050/891 | Cost: 0.2082\n",
      "Epoch: 012/050 | Batch 100/891 | Cost: 0.1527\n",
      "Epoch: 012/050 | Batch 150/891 | Cost: 0.1564\n",
      "Epoch: 012/050 | Batch 200/891 | Cost: 0.2391\n",
      "Epoch: 012/050 | Batch 250/891 | Cost: 0.3568\n",
      "Epoch: 012/050 | Batch 300/891 | Cost: 0.0926\n",
      "Epoch: 012/050 | Batch 350/891 | Cost: 0.1798\n",
      "Epoch: 012/050 | Batch 400/891 | Cost: 0.2591\n",
      "Epoch: 012/050 | Batch 450/891 | Cost: 0.2005\n",
      "Epoch: 012/050 | Batch 500/891 | Cost: 0.1461\n",
      "Epoch: 012/050 | Batch 550/891 | Cost: 0.2099\n",
      "Epoch: 012/050 | Batch 600/891 | Cost: 0.1473\n",
      "Epoch: 012/050 | Batch 650/891 | Cost: 0.2052\n",
      "Epoch: 012/050 | Batch 700/891 | Cost: 0.2090\n",
      "Epoch: 012/050 | Batch 750/891 | Cost: 0.3133\n",
      "Epoch: 012/050 | Batch 800/891 | Cost: 0.0936\n",
      "Epoch: 012/050 | Batch 850/891 | Cost: 0.1964\n",
      "training accuracy: 94.91%\n",
      "valid accuracy: 91.92%\n",
      "Time elapsed: 5.71 min\n",
      "Epoch: 013/050 | Batch 000/891 | Cost: 0.1882\n",
      "Epoch: 013/050 | Batch 050/891 | Cost: 0.1726\n",
      "Epoch: 013/050 | Batch 100/891 | Cost: 0.2273\n",
      "Epoch: 013/050 | Batch 150/891 | Cost: 0.4143\n",
      "Epoch: 013/050 | Batch 200/891 | Cost: 0.1912\n",
      "Epoch: 013/050 | Batch 250/891 | Cost: 0.1610\n",
      "Epoch: 013/050 | Batch 300/891 | Cost: 0.2238\n",
      "Epoch: 013/050 | Batch 350/891 | Cost: 0.3671\n",
      "Epoch: 013/050 | Batch 400/891 | Cost: 0.1471\n",
      "Epoch: 013/050 | Batch 450/891 | Cost: 0.2440\n",
      "Epoch: 013/050 | Batch 500/891 | Cost: 0.2701\n",
      "Epoch: 013/050 | Batch 550/891 | Cost: 0.2684\n",
      "Epoch: 013/050 | Batch 600/891 | Cost: 0.1602\n",
      "Epoch: 013/050 | Batch 650/891 | Cost: 0.2128\n",
      "Epoch: 013/050 | Batch 700/891 | Cost: 0.0978\n",
      "Epoch: 013/050 | Batch 750/891 | Cost: 0.2017\n",
      "Epoch: 013/050 | Batch 800/891 | Cost: 0.0781\n",
      "Epoch: 013/050 | Batch 850/891 | Cost: 0.2742\n",
      "training accuracy: 95.13%\n",
      "valid accuracy: 91.83%\n",
      "Time elapsed: 6.19 min\n",
      "Epoch: 014/050 | Batch 000/891 | Cost: 0.1590\n",
      "Epoch: 014/050 | Batch 050/891 | Cost: 0.1685\n",
      "Epoch: 014/050 | Batch 100/891 | Cost: 0.2997\n",
      "Epoch: 014/050 | Batch 150/891 | Cost: 0.0779\n",
      "Epoch: 014/050 | Batch 200/891 | Cost: 0.1422\n",
      "Epoch: 014/050 | Batch 250/891 | Cost: 0.2610\n",
      "Epoch: 014/050 | Batch 300/891 | Cost: 0.2471\n",
      "Epoch: 014/050 | Batch 350/891 | Cost: 0.1126\n",
      "Epoch: 014/050 | Batch 400/891 | Cost: 0.5214\n",
      "Epoch: 014/050 | Batch 450/891 | Cost: 0.1805\n",
      "Epoch: 014/050 | Batch 500/891 | Cost: 0.3690\n",
      "Epoch: 014/050 | Batch 550/891 | Cost: 0.1889\n",
      "Epoch: 014/050 | Batch 600/891 | Cost: 0.2583\n",
      "Epoch: 014/050 | Batch 650/891 | Cost: 0.1955\n",
      "Epoch: 014/050 | Batch 700/891 | Cost: 0.3968\n",
      "Epoch: 014/050 | Batch 750/891 | Cost: 0.4153\n",
      "Epoch: 014/050 | Batch 800/891 | Cost: 0.2386\n",
      "Epoch: 014/050 | Batch 850/891 | Cost: 0.2618\n",
      "training accuracy: 95.26%\n",
      "valid accuracy: 91.82%\n",
      "Time elapsed: 6.68 min\n",
      "Epoch: 015/050 | Batch 000/891 | Cost: 0.2095\n",
      "Epoch: 015/050 | Batch 050/891 | Cost: 0.1248\n",
      "Epoch: 015/050 | Batch 100/891 | Cost: 0.2129\n",
      "Epoch: 015/050 | Batch 150/891 | Cost: 0.1529\n",
      "Epoch: 015/050 | Batch 200/891 | Cost: 0.1211\n",
      "Epoch: 015/050 | Batch 250/891 | Cost: 0.2485\n",
      "Epoch: 015/050 | Batch 300/891 | Cost: 0.1596\n",
      "Epoch: 015/050 | Batch 350/891 | Cost: 0.2131\n",
      "Epoch: 015/050 | Batch 400/891 | Cost: 0.2655\n",
      "Epoch: 015/050 | Batch 450/891 | Cost: 0.1532\n",
      "Epoch: 015/050 | Batch 500/891 | Cost: 0.1442\n",
      "Epoch: 015/050 | Batch 550/891 | Cost: 0.2170\n",
      "Epoch: 015/050 | Batch 600/891 | Cost: 0.2097\n",
      "Epoch: 015/050 | Batch 650/891 | Cost: 0.1731\n",
      "Epoch: 015/050 | Batch 700/891 | Cost: 0.2049\n",
      "Epoch: 015/050 | Batch 750/891 | Cost: 0.1335\n",
      "Epoch: 015/050 | Batch 800/891 | Cost: 0.1869\n",
      "Epoch: 015/050 | Batch 850/891 | Cost: 0.1313\n",
      "training accuracy: 95.62%\n",
      "valid accuracy: 91.93%\n",
      "Time elapsed: 7.17 min\n",
      "Epoch: 016/050 | Batch 000/891 | Cost: 0.2243\n",
      "Epoch: 016/050 | Batch 050/891 | Cost: 0.1787\n",
      "Epoch: 016/050 | Batch 100/891 | Cost: 0.0720\n",
      "Epoch: 016/050 | Batch 150/891 | Cost: 0.1693\n",
      "Epoch: 016/050 | Batch 200/891 | Cost: 0.0990\n",
      "Epoch: 016/050 | Batch 250/891 | Cost: 0.2836\n",
      "Epoch: 016/050 | Batch 300/891 | Cost: 0.1295\n",
      "Epoch: 016/050 | Batch 350/891 | Cost: 0.0999\n",
      "Epoch: 016/050 | Batch 400/891 | Cost: 0.1612\n",
      "Epoch: 016/050 | Batch 450/891 | Cost: 0.2436\n",
      "Epoch: 016/050 | Batch 500/891 | Cost: 0.2344\n",
      "Epoch: 016/050 | Batch 550/891 | Cost: 0.2931\n",
      "Epoch: 016/050 | Batch 600/891 | Cost: 0.0864\n",
      "Epoch: 016/050 | Batch 650/891 | Cost: 0.2007\n",
      "Epoch: 016/050 | Batch 700/891 | Cost: 0.1101\n",
      "Epoch: 016/050 | Batch 750/891 | Cost: 0.2093\n",
      "Epoch: 016/050 | Batch 800/891 | Cost: 0.1148\n",
      "Epoch: 016/050 | Batch 850/891 | Cost: 0.1621\n",
      "training accuracy: 95.66%\n",
      "valid accuracy: 91.53%\n",
      "Time elapsed: 7.65 min\n",
      "Epoch: 017/050 | Batch 000/891 | Cost: 0.3486\n",
      "Epoch: 017/050 | Batch 050/891 | Cost: 0.1839\n",
      "Epoch: 017/050 | Batch 100/891 | Cost: 0.0831\n",
      "Epoch: 017/050 | Batch 150/891 | Cost: 0.1529\n",
      "Epoch: 017/050 | Batch 200/891 | Cost: 0.2675\n",
      "Epoch: 017/050 | Batch 250/891 | Cost: 0.1468\n",
      "Epoch: 017/050 | Batch 300/891 | Cost: 0.1797\n",
      "Epoch: 017/050 | Batch 350/891 | Cost: 0.1586\n",
      "Epoch: 017/050 | Batch 400/891 | Cost: 0.1128\n",
      "Epoch: 017/050 | Batch 450/891 | Cost: 0.1678\n",
      "Epoch: 017/050 | Batch 500/891 | Cost: 0.1740\n",
      "Epoch: 017/050 | Batch 550/891 | Cost: 0.2684\n",
      "Epoch: 017/050 | Batch 600/891 | Cost: 0.1596\n",
      "Epoch: 017/050 | Batch 650/891 | Cost: 0.2647\n",
      "Epoch: 017/050 | Batch 700/891 | Cost: 0.1738\n",
      "Epoch: 017/050 | Batch 750/891 | Cost: 0.2119\n",
      "Epoch: 017/050 | Batch 800/891 | Cost: 0.1385\n",
      "Epoch: 017/050 | Batch 850/891 | Cost: 0.1648\n",
      "training accuracy: 95.90%\n",
      "valid accuracy: 91.77%\n",
      "Time elapsed: 8.14 min\n",
      "Epoch: 018/050 | Batch 000/891 | Cost: 0.1678\n",
      "Epoch: 018/050 | Batch 050/891 | Cost: 0.1260\n",
      "Epoch: 018/050 | Batch 100/891 | Cost: 0.1912\n",
      "Epoch: 018/050 | Batch 150/891 | Cost: 0.1299\n",
      "Epoch: 018/050 | Batch 200/891 | Cost: 0.1702\n",
      "Epoch: 018/050 | Batch 250/891 | Cost: 0.1456\n",
      "Epoch: 018/050 | Batch 300/891 | Cost: 0.1284\n",
      "Epoch: 018/050 | Batch 350/891 | Cost: 0.2763\n",
      "Epoch: 018/050 | Batch 400/891 | Cost: 0.0950\n",
      "Epoch: 018/050 | Batch 450/891 | Cost: 0.1417\n",
      "Epoch: 018/050 | Batch 500/891 | Cost: 0.2453\n",
      "Epoch: 018/050 | Batch 550/891 | Cost: 0.2603\n",
      "Epoch: 018/050 | Batch 600/891 | Cost: 0.2635\n",
      "Epoch: 018/050 | Batch 650/891 | Cost: 0.1849\n",
      "Epoch: 018/050 | Batch 700/891 | Cost: 0.1742\n",
      "Epoch: 018/050 | Batch 750/891 | Cost: 0.1185\n",
      "Epoch: 018/050 | Batch 800/891 | Cost: 0.4024\n",
      "Epoch: 018/050 | Batch 850/891 | Cost: 0.1221\n",
      "training accuracy: 96.03%\n",
      "valid accuracy: 91.83%\n",
      "Time elapsed: 8.63 min\n",
      "Epoch: 019/050 | Batch 000/891 | Cost: 0.1801\n",
      "Epoch: 019/050 | Batch 050/891 | Cost: 0.2904\n",
      "Epoch: 019/050 | Batch 100/891 | Cost: 0.1423\n",
      "Epoch: 019/050 | Batch 150/891 | Cost: 0.2176\n",
      "Epoch: 019/050 | Batch 200/891 | Cost: 0.2692\n",
      "Epoch: 019/050 | Batch 250/891 | Cost: 0.1769\n",
      "Epoch: 019/050 | Batch 300/891 | Cost: 0.1792\n",
      "Epoch: 019/050 | Batch 350/891 | Cost: 0.4244\n",
      "Epoch: 019/050 | Batch 400/891 | Cost: 0.1208\n",
      "Epoch: 019/050 | Batch 450/891 | Cost: 0.3000\n",
      "Epoch: 019/050 | Batch 500/891 | Cost: 0.1977\n",
      "Epoch: 019/050 | Batch 550/891 | Cost: 0.2125\n",
      "Epoch: 019/050 | Batch 600/891 | Cost: 0.1181\n",
      "Epoch: 019/050 | Batch 650/891 | Cost: 0.1804\n",
      "Epoch: 019/050 | Batch 700/891 | Cost: 0.1098\n",
      "Epoch: 019/050 | Batch 750/891 | Cost: 0.2638\n",
      "Epoch: 019/050 | Batch 800/891 | Cost: 0.1524\n",
      "Epoch: 019/050 | Batch 850/891 | Cost: 0.2061\n",
      "training accuracy: 96.25%\n",
      "valid accuracy: 91.98%\n",
      "Time elapsed: 9.11 min\n",
      "Epoch: 020/050 | Batch 000/891 | Cost: 0.1519\n",
      "Epoch: 020/050 | Batch 050/891 | Cost: 0.1323\n",
      "Epoch: 020/050 | Batch 100/891 | Cost: 0.1637\n",
      "Epoch: 020/050 | Batch 150/891 | Cost: 0.2232\n",
      "Epoch: 020/050 | Batch 200/891 | Cost: 0.4358\n",
      "Epoch: 020/050 | Batch 250/891 | Cost: 0.1855\n",
      "Epoch: 020/050 | Batch 300/891 | Cost: 0.2004\n",
      "Epoch: 020/050 | Batch 350/891 | Cost: 0.0560\n",
      "Epoch: 020/050 | Batch 400/891 | Cost: 0.0841\n",
      "Epoch: 020/050 | Batch 450/891 | Cost: 0.0601\n",
      "Epoch: 020/050 | Batch 500/891 | Cost: 0.0987\n",
      "Epoch: 020/050 | Batch 550/891 | Cost: 0.1021\n",
      "Epoch: 020/050 | Batch 600/891 | Cost: 0.4316\n",
      "Epoch: 020/050 | Batch 650/891 | Cost: 0.1060\n",
      "Epoch: 020/050 | Batch 700/891 | Cost: 0.1655\n",
      "Epoch: 020/050 | Batch 750/891 | Cost: 0.1303\n",
      "Epoch: 020/050 | Batch 800/891 | Cost: 0.2889\n",
      "Epoch: 020/050 | Batch 850/891 | Cost: 0.0948\n",
      "training accuracy: 96.46%\n",
      "valid accuracy: 91.87%\n",
      "Time elapsed: 9.59 min\n",
      "Epoch: 021/050 | Batch 000/891 | Cost: 0.1513\n",
      "Epoch: 021/050 | Batch 050/891 | Cost: 0.1063\n",
      "Epoch: 021/050 | Batch 100/891 | Cost: 0.0808\n",
      "Epoch: 021/050 | Batch 150/891 | Cost: 0.1390\n",
      "Epoch: 021/050 | Batch 200/891 | Cost: 0.1452\n",
      "Epoch: 021/050 | Batch 250/891 | Cost: 0.2100\n",
      "Epoch: 021/050 | Batch 300/891 | Cost: 0.1803\n",
      "Epoch: 021/050 | Batch 350/891 | Cost: 0.1057\n",
      "Epoch: 021/050 | Batch 400/891 | Cost: 0.1293\n",
      "Epoch: 021/050 | Batch 450/891 | Cost: 0.1064\n",
      "Epoch: 021/050 | Batch 500/891 | Cost: 0.1383\n",
      "Epoch: 021/050 | Batch 550/891 | Cost: 0.1331\n",
      "Epoch: 021/050 | Batch 600/891 | Cost: 0.2483\n",
      "Epoch: 021/050 | Batch 650/891 | Cost: 0.1053\n",
      "Epoch: 021/050 | Batch 700/891 | Cost: 0.0852\n",
      "Epoch: 021/050 | Batch 750/891 | Cost: 0.0939\n",
      "Epoch: 021/050 | Batch 800/891 | Cost: 0.1492\n",
      "Epoch: 021/050 | Batch 850/891 | Cost: 0.1075\n",
      "training accuracy: 96.56%\n",
      "valid accuracy: 92.13%\n",
      "Time elapsed: 10.33 min\n",
      "Epoch: 022/050 | Batch 000/891 | Cost: 0.1810\n",
      "Epoch: 022/050 | Batch 050/891 | Cost: 0.1069\n",
      "Epoch: 022/050 | Batch 100/891 | Cost: 0.1601\n",
      "Epoch: 022/050 | Batch 150/891 | Cost: 0.1092\n",
      "Epoch: 022/050 | Batch 200/891 | Cost: 0.2255\n",
      "Epoch: 022/050 | Batch 250/891 | Cost: 0.3778\n",
      "Epoch: 022/050 | Batch 300/891 | Cost: 0.1875\n",
      "Epoch: 022/050 | Batch 350/891 | Cost: 0.1854\n",
      "Epoch: 022/050 | Batch 400/891 | Cost: 0.3620\n",
      "Epoch: 022/050 | Batch 450/891 | Cost: 0.1210\n",
      "Epoch: 022/050 | Batch 500/891 | Cost: 0.0647\n",
      "Epoch: 022/050 | Batch 550/891 | Cost: 0.2215\n",
      "Epoch: 022/050 | Batch 600/891 | Cost: 0.1141\n",
      "Epoch: 022/050 | Batch 650/891 | Cost: 0.1765\n",
      "Epoch: 022/050 | Batch 700/891 | Cost: 0.1067\n",
      "Epoch: 022/050 | Batch 750/891 | Cost: 0.1907\n",
      "Epoch: 022/050 | Batch 800/891 | Cost: 0.1374\n",
      "Epoch: 022/050 | Batch 850/891 | Cost: 0.1366\n",
      "training accuracy: 96.75%\n",
      "valid accuracy: 91.88%\n",
      "Time elapsed: 11.35 min\n",
      "Epoch: 023/050 | Batch 000/891 | Cost: 0.0993\n",
      "Epoch: 023/050 | Batch 050/891 | Cost: 0.1212\n",
      "Epoch: 023/050 | Batch 100/891 | Cost: 0.1991\n",
      "Epoch: 023/050 | Batch 150/891 | Cost: 0.2732\n",
      "Epoch: 023/050 | Batch 200/891 | Cost: 0.2020\n",
      "Epoch: 023/050 | Batch 250/891 | Cost: 0.0996\n",
      "Epoch: 023/050 | Batch 300/891 | Cost: 0.2931\n",
      "Epoch: 023/050 | Batch 350/891 | Cost: 0.1590\n",
      "Epoch: 023/050 | Batch 400/891 | Cost: 0.3799\n",
      "Epoch: 023/050 | Batch 450/891 | Cost: 0.2423\n",
      "Epoch: 023/050 | Batch 500/891 | Cost: 0.1465\n",
      "Epoch: 023/050 | Batch 550/891 | Cost: 0.1157\n",
      "Epoch: 023/050 | Batch 600/891 | Cost: 0.2244\n",
      "Epoch: 023/050 | Batch 650/891 | Cost: 0.1930\n",
      "Epoch: 023/050 | Batch 700/891 | Cost: 0.1244\n",
      "Epoch: 023/050 | Batch 750/891 | Cost: 0.1410\n",
      "Epoch: 023/050 | Batch 800/891 | Cost: 0.1642\n",
      "Epoch: 023/050 | Batch 850/891 | Cost: 0.1734\n",
      "training accuracy: 96.90%\n",
      "valid accuracy: 91.63%\n",
      "Time elapsed: 12.39 min\n",
      "Epoch: 024/050 | Batch 000/891 | Cost: 0.0709\n",
      "Epoch: 024/050 | Batch 050/891 | Cost: 0.1248\n",
      "Epoch: 024/050 | Batch 100/891 | Cost: 0.1629\n",
      "Epoch: 024/050 | Batch 150/891 | Cost: 0.1777\n",
      "Epoch: 024/050 | Batch 200/891 | Cost: 0.2100\n",
      "Epoch: 024/050 | Batch 250/891 | Cost: 0.1991\n",
      "Epoch: 024/050 | Batch 300/891 | Cost: 0.4561\n",
      "Epoch: 024/050 | Batch 350/891 | Cost: 0.1529\n",
      "Epoch: 024/050 | Batch 400/891 | Cost: 0.1097\n",
      "Epoch: 024/050 | Batch 450/891 | Cost: 0.1213\n",
      "Epoch: 024/050 | Batch 500/891 | Cost: 0.1387\n",
      "Epoch: 024/050 | Batch 550/891 | Cost: 0.2177\n",
      "Epoch: 024/050 | Batch 600/891 | Cost: 0.1028\n",
      "Epoch: 024/050 | Batch 650/891 | Cost: 0.2664\n",
      "Epoch: 024/050 | Batch 700/891 | Cost: 0.0694\n",
      "Epoch: 024/050 | Batch 750/891 | Cost: 0.0847\n",
      "Epoch: 024/050 | Batch 800/891 | Cost: 0.1983\n",
      "Epoch: 024/050 | Batch 850/891 | Cost: 0.2498\n",
      "training accuracy: 97.16%\n",
      "valid accuracy: 91.93%\n",
      "Time elapsed: 13.42 min\n",
      "Epoch: 025/050 | Batch 000/891 | Cost: 0.1991\n",
      "Epoch: 025/050 | Batch 050/891 | Cost: 0.0666\n",
      "Epoch: 025/050 | Batch 100/891 | Cost: 0.1780\n",
      "Epoch: 025/050 | Batch 150/891 | Cost: 0.1563\n",
      "Epoch: 025/050 | Batch 200/891 | Cost: 0.0882\n",
      "Epoch: 025/050 | Batch 250/891 | Cost: 0.2989\n",
      "Epoch: 025/050 | Batch 300/891 | Cost: 0.1824\n",
      "Epoch: 025/050 | Batch 350/891 | Cost: 0.2966\n",
      "Epoch: 025/050 | Batch 400/891 | Cost: 0.2031\n",
      "Epoch: 025/050 | Batch 450/891 | Cost: 0.1180\n",
      "Epoch: 025/050 | Batch 500/891 | Cost: 0.3109\n",
      "Epoch: 025/050 | Batch 550/891 | Cost: 0.1684\n",
      "Epoch: 025/050 | Batch 600/891 | Cost: 0.0875\n",
      "Epoch: 025/050 | Batch 650/891 | Cost: 0.1391\n",
      "Epoch: 025/050 | Batch 700/891 | Cost: 0.1274\n",
      "Epoch: 025/050 | Batch 750/891 | Cost: 0.2153\n",
      "Epoch: 025/050 | Batch 800/891 | Cost: 0.1216\n",
      "Epoch: 025/050 | Batch 850/891 | Cost: 0.1828\n",
      "training accuracy: 97.05%\n",
      "valid accuracy: 91.38%\n",
      "Time elapsed: 14.47 min\n",
      "Epoch: 026/050 | Batch 000/891 | Cost: 0.1344\n",
      "Epoch: 026/050 | Batch 050/891 | Cost: 0.2940\n",
      "Epoch: 026/050 | Batch 100/891 | Cost: 0.1692\n",
      "Epoch: 026/050 | Batch 150/891 | Cost: 0.1281\n",
      "Epoch: 026/050 | Batch 200/891 | Cost: 0.1737\n",
      "Epoch: 026/050 | Batch 250/891 | Cost: 0.2194\n",
      "Epoch: 026/050 | Batch 300/891 | Cost: 0.3692\n",
      "Epoch: 026/050 | Batch 350/891 | Cost: 0.2095\n",
      "Epoch: 026/050 | Batch 400/891 | Cost: 0.2085\n",
      "Epoch: 026/050 | Batch 450/891 | Cost: 0.2011\n",
      "Epoch: 026/050 | Batch 500/891 | Cost: 0.2066\n",
      "Epoch: 026/050 | Batch 550/891 | Cost: 0.3383\n",
      "Epoch: 026/050 | Batch 600/891 | Cost: 0.2015\n",
      "Epoch: 026/050 | Batch 650/891 | Cost: 0.1520\n",
      "Epoch: 026/050 | Batch 700/891 | Cost: 0.0984\n",
      "Epoch: 026/050 | Batch 750/891 | Cost: 0.0933\n",
      "Epoch: 026/050 | Batch 800/891 | Cost: 0.2503\n",
      "Epoch: 026/050 | Batch 850/891 | Cost: 0.1500\n",
      "training accuracy: 97.30%\n",
      "valid accuracy: 91.88%\n",
      "Time elapsed: 15.54 min\n",
      "Epoch: 027/050 | Batch 000/891 | Cost: 0.1133\n",
      "Epoch: 027/050 | Batch 050/891 | Cost: 0.0566\n",
      "Epoch: 027/050 | Batch 100/891 | Cost: 0.1300\n",
      "Epoch: 027/050 | Batch 150/891 | Cost: 0.1017\n",
      "Epoch: 027/050 | Batch 200/891 | Cost: 0.1233\n",
      "Epoch: 027/050 | Batch 250/891 | Cost: 0.2639\n",
      "Epoch: 027/050 | Batch 300/891 | Cost: 0.1417\n",
      "Epoch: 027/050 | Batch 350/891 | Cost: 0.1526\n",
      "Epoch: 027/050 | Batch 400/891 | Cost: 0.1113\n",
      "Epoch: 027/050 | Batch 450/891 | Cost: 0.1807\n",
      "Epoch: 027/050 | Batch 500/891 | Cost: 0.2097\n",
      "Epoch: 027/050 | Batch 550/891 | Cost: 0.0656\n",
      "Epoch: 027/050 | Batch 600/891 | Cost: 0.1461\n",
      "Epoch: 027/050 | Batch 650/891 | Cost: 0.0721\n",
      "Epoch: 027/050 | Batch 700/891 | Cost: 0.1089\n",
      "Epoch: 027/050 | Batch 750/891 | Cost: 0.1491\n",
      "Epoch: 027/050 | Batch 800/891 | Cost: 0.2305\n",
      "Epoch: 027/050 | Batch 850/891 | Cost: 0.1258\n",
      "training accuracy: 97.38%\n",
      "valid accuracy: 92.00%\n",
      "Time elapsed: 16.61 min\n",
      "Epoch: 028/050 | Batch 000/891 | Cost: 0.0894\n",
      "Epoch: 028/050 | Batch 050/891 | Cost: 0.1093\n",
      "Epoch: 028/050 | Batch 100/891 | Cost: 0.1931\n",
      "Epoch: 028/050 | Batch 150/891 | Cost: 0.1843\n",
      "Epoch: 028/050 | Batch 200/891 | Cost: 0.1760\n",
      "Epoch: 028/050 | Batch 250/891 | Cost: 0.0717\n",
      "Epoch: 028/050 | Batch 300/891 | Cost: 0.1854\n",
      "Epoch: 028/050 | Batch 350/891 | Cost: 0.1044\n",
      "Epoch: 028/050 | Batch 400/891 | Cost: 0.1138\n",
      "Epoch: 028/050 | Batch 450/891 | Cost: 0.1639\n",
      "Epoch: 028/050 | Batch 500/891 | Cost: 0.1970\n",
      "Epoch: 028/050 | Batch 550/891 | Cost: 0.0855\n",
      "Epoch: 028/050 | Batch 600/891 | Cost: 0.0979\n",
      "Epoch: 028/050 | Batch 650/891 | Cost: 0.1288\n",
      "Epoch: 028/050 | Batch 700/891 | Cost: 0.1454\n",
      "Epoch: 028/050 | Batch 750/891 | Cost: 0.0631\n",
      "Epoch: 028/050 | Batch 800/891 | Cost: 0.1604\n",
      "Epoch: 028/050 | Batch 850/891 | Cost: 0.1495\n",
      "training accuracy: 97.54%\n",
      "valid accuracy: 91.87%\n",
      "Time elapsed: 17.68 min\n",
      "Epoch: 029/050 | Batch 000/891 | Cost: 0.0644\n",
      "Epoch: 029/050 | Batch 050/891 | Cost: 0.0699\n",
      "Epoch: 029/050 | Batch 100/891 | Cost: 0.2319\n",
      "Epoch: 029/050 | Batch 150/891 | Cost: 0.1196\n",
      "Epoch: 029/050 | Batch 200/891 | Cost: 0.0950\n",
      "Epoch: 029/050 | Batch 250/891 | Cost: 0.1323\n",
      "Epoch: 029/050 | Batch 300/891 | Cost: 0.2933\n",
      "Epoch: 029/050 | Batch 350/891 | Cost: 0.1934\n",
      "Epoch: 029/050 | Batch 400/891 | Cost: 0.0852\n",
      "Epoch: 029/050 | Batch 450/891 | Cost: 0.1402\n",
      "Epoch: 029/050 | Batch 500/891 | Cost: 0.2230\n",
      "Epoch: 029/050 | Batch 550/891 | Cost: 0.0998\n",
      "Epoch: 029/050 | Batch 600/891 | Cost: 0.1782\n",
      "Epoch: 029/050 | Batch 650/891 | Cost: 0.3283\n",
      "Epoch: 029/050 | Batch 700/891 | Cost: 0.2203\n",
      "Epoch: 029/050 | Batch 750/891 | Cost: 0.1579\n",
      "Epoch: 029/050 | Batch 800/891 | Cost: 0.1457\n",
      "Epoch: 029/050 | Batch 850/891 | Cost: 0.2025\n",
      "training accuracy: 97.45%\n",
      "valid accuracy: 91.53%\n",
      "Time elapsed: 18.74 min\n",
      "Epoch: 030/050 | Batch 000/891 | Cost: 0.0462\n",
      "Epoch: 030/050 | Batch 050/891 | Cost: 0.1564\n",
      "Epoch: 030/050 | Batch 100/891 | Cost: 0.0746\n",
      "Epoch: 030/050 | Batch 150/891 | Cost: 0.1384\n",
      "Epoch: 030/050 | Batch 200/891 | Cost: 0.2740\n",
      "Epoch: 030/050 | Batch 250/891 | Cost: 0.3271\n",
      "Epoch: 030/050 | Batch 300/891 | Cost: 0.1764\n",
      "Epoch: 030/050 | Batch 350/891 | Cost: 0.1777\n",
      "Epoch: 030/050 | Batch 400/891 | Cost: 0.0841\n",
      "Epoch: 030/050 | Batch 450/891 | Cost: 0.1597\n",
      "Epoch: 030/050 | Batch 500/891 | Cost: 0.1223\n",
      "Epoch: 030/050 | Batch 550/891 | Cost: 0.1083\n",
      "Epoch: 030/050 | Batch 600/891 | Cost: 0.1478\n",
      "Epoch: 030/050 | Batch 650/891 | Cost: 0.2959\n",
      "Epoch: 030/050 | Batch 700/891 | Cost: 0.1887\n",
      "Epoch: 030/050 | Batch 750/891 | Cost: 0.2498\n",
      "Epoch: 030/050 | Batch 800/891 | Cost: 0.1300\n",
      "Epoch: 030/050 | Batch 850/891 | Cost: 0.1651\n",
      "training accuracy: 97.41%\n",
      "valid accuracy: 91.53%\n",
      "Time elapsed: 19.80 min\n",
      "Epoch: 031/050 | Batch 000/891 | Cost: 0.2204\n",
      "Epoch: 031/050 | Batch 050/891 | Cost: 0.0253\n",
      "Epoch: 031/050 | Batch 100/891 | Cost: 0.2895\n",
      "Epoch: 031/050 | Batch 150/891 | Cost: 0.1715\n",
      "Epoch: 031/050 | Batch 200/891 | Cost: 0.1887\n",
      "Epoch: 031/050 | Batch 250/891 | Cost: 0.2059\n",
      "Epoch: 031/050 | Batch 300/891 | Cost: 0.0932\n",
      "Epoch: 031/050 | Batch 350/891 | Cost: 0.1699\n",
      "Epoch: 031/050 | Batch 400/891 | Cost: 0.0939\n",
      "Epoch: 031/050 | Batch 450/891 | Cost: 0.1887\n",
      "Epoch: 031/050 | Batch 500/891 | Cost: 0.1506\n",
      "Epoch: 031/050 | Batch 550/891 | Cost: 0.0940\n",
      "Epoch: 031/050 | Batch 600/891 | Cost: 0.0522\n",
      "Epoch: 031/050 | Batch 650/891 | Cost: 0.0805\n",
      "Epoch: 031/050 | Batch 700/891 | Cost: 0.1576\n",
      "Epoch: 031/050 | Batch 750/891 | Cost: 0.0976\n",
      "Epoch: 031/050 | Batch 800/891 | Cost: 0.2967\n",
      "Epoch: 031/050 | Batch 850/891 | Cost: 0.1926\n",
      "training accuracy: 97.74%\n",
      "valid accuracy: 91.80%\n",
      "Time elapsed: 20.79 min\n",
      "Epoch: 032/050 | Batch 000/891 | Cost: 0.2118\n",
      "Epoch: 032/050 | Batch 050/891 | Cost: 0.1500\n",
      "Epoch: 032/050 | Batch 100/891 | Cost: 0.0699\n",
      "Epoch: 032/050 | Batch 150/891 | Cost: 0.1424\n",
      "Epoch: 032/050 | Batch 200/891 | Cost: 0.2768\n",
      "Epoch: 032/050 | Batch 250/891 | Cost: 0.0965\n",
      "Epoch: 032/050 | Batch 300/891 | Cost: 0.0836\n",
      "Epoch: 032/050 | Batch 350/891 | Cost: 0.1566\n",
      "Epoch: 032/050 | Batch 400/891 | Cost: 0.1140\n",
      "Epoch: 032/050 | Batch 450/891 | Cost: 0.1286\n",
      "Epoch: 032/050 | Batch 500/891 | Cost: 0.1687\n",
      "Epoch: 032/050 | Batch 550/891 | Cost: 0.0647\n",
      "Epoch: 032/050 | Batch 600/891 | Cost: 0.0885\n",
      "Epoch: 032/050 | Batch 650/891 | Cost: 0.0491\n",
      "Epoch: 032/050 | Batch 700/891 | Cost: 0.0612\n",
      "Epoch: 032/050 | Batch 750/891 | Cost: 0.0645\n",
      "Epoch: 032/050 | Batch 800/891 | Cost: 0.2246\n",
      "Epoch: 032/050 | Batch 850/891 | Cost: 0.0900\n",
      "training accuracy: 97.75%\n",
      "valid accuracy: 91.77%\n",
      "Time elapsed: 21.75 min\n",
      "Epoch: 033/050 | Batch 000/891 | Cost: 0.1070\n",
      "Epoch: 033/050 | Batch 050/891 | Cost: 0.1982\n",
      "Epoch: 033/050 | Batch 100/891 | Cost: 0.1159\n",
      "Epoch: 033/050 | Batch 150/891 | Cost: 0.1398\n",
      "Epoch: 033/050 | Batch 200/891 | Cost: 0.0937\n",
      "Epoch: 033/050 | Batch 250/891 | Cost: 0.1015\n",
      "Epoch: 033/050 | Batch 300/891 | Cost: 0.0945\n",
      "Epoch: 033/050 | Batch 350/891 | Cost: 0.0534\n",
      "Epoch: 033/050 | Batch 400/891 | Cost: 0.1476\n",
      "Epoch: 033/050 | Batch 450/891 | Cost: 0.0937\n",
      "Epoch: 033/050 | Batch 500/891 | Cost: 0.2442\n",
      "Epoch: 033/050 | Batch 550/891 | Cost: 0.0817\n",
      "Epoch: 033/050 | Batch 600/891 | Cost: 0.2181\n",
      "Epoch: 033/050 | Batch 650/891 | Cost: 0.2121\n",
      "Epoch: 033/050 | Batch 700/891 | Cost: 0.1767\n",
      "Epoch: 033/050 | Batch 750/891 | Cost: 0.2248\n",
      "Epoch: 033/050 | Batch 800/891 | Cost: 0.1277\n",
      "Epoch: 033/050 | Batch 850/891 | Cost: 0.1004\n",
      "training accuracy: 97.88%\n",
      "valid accuracy: 91.63%\n",
      "Time elapsed: 22.69 min\n",
      "Epoch: 034/050 | Batch 000/891 | Cost: 0.1261\n",
      "Epoch: 034/050 | Batch 050/891 | Cost: 0.1267\n",
      "Epoch: 034/050 | Batch 100/891 | Cost: 0.1777\n",
      "Epoch: 034/050 | Batch 150/891 | Cost: 0.2866\n",
      "Epoch: 034/050 | Batch 200/891 | Cost: 0.0845\n",
      "Epoch: 034/050 | Batch 250/891 | Cost: 0.2171\n",
      "Epoch: 034/050 | Batch 300/891 | Cost: 0.1906\n",
      "Epoch: 034/050 | Batch 350/891 | Cost: 0.1531\n",
      "Epoch: 034/050 | Batch 400/891 | Cost: 0.0928\n",
      "Epoch: 034/050 | Batch 450/891 | Cost: 0.1674\n",
      "Epoch: 034/050 | Batch 500/891 | Cost: 0.2959\n",
      "Epoch: 034/050 | Batch 550/891 | Cost: 0.1654\n",
      "Epoch: 034/050 | Batch 600/891 | Cost: 0.2238\n",
      "Epoch: 034/050 | Batch 650/891 | Cost: 0.1358\n",
      "Epoch: 034/050 | Batch 700/891 | Cost: 0.0593\n",
      "Epoch: 034/050 | Batch 750/891 | Cost: 0.2061\n",
      "Epoch: 034/050 | Batch 800/891 | Cost: 0.0418\n",
      "Epoch: 034/050 | Batch 850/891 | Cost: 0.1814\n",
      "training accuracy: 97.77%\n",
      "valid accuracy: 91.53%\n",
      "Time elapsed: 23.67 min\n",
      "Epoch: 035/050 | Batch 000/891 | Cost: 0.2832\n",
      "Epoch: 035/050 | Batch 050/891 | Cost: 0.0631\n",
      "Epoch: 035/050 | Batch 100/891 | Cost: 0.1005\n",
      "Epoch: 035/050 | Batch 150/891 | Cost: 0.1677\n",
      "Epoch: 035/050 | Batch 200/891 | Cost: 0.0663\n",
      "Epoch: 035/050 | Batch 250/891 | Cost: 0.1370\n",
      "Epoch: 035/050 | Batch 300/891 | Cost: 0.1260\n",
      "Epoch: 035/050 | Batch 350/891 | Cost: 0.1642\n",
      "Epoch: 035/050 | Batch 400/891 | Cost: 0.1703\n",
      "Epoch: 035/050 | Batch 450/891 | Cost: 0.1147\n",
      "Epoch: 035/050 | Batch 500/891 | Cost: 0.1205\n",
      "Epoch: 035/050 | Batch 550/891 | Cost: 0.1352\n",
      "Epoch: 035/050 | Batch 600/891 | Cost: 0.1017\n",
      "Epoch: 035/050 | Batch 650/891 | Cost: 0.2116\n",
      "Epoch: 035/050 | Batch 700/891 | Cost: 0.1301\n",
      "Epoch: 035/050 | Batch 750/891 | Cost: 0.1565\n",
      "Epoch: 035/050 | Batch 800/891 | Cost: 0.0610\n",
      "Epoch: 035/050 | Batch 850/891 | Cost: 0.1000\n",
      "training accuracy: 98.02%\n",
      "valid accuracy: 91.92%\n",
      "Time elapsed: 24.75 min\n",
      "Epoch: 036/050 | Batch 000/891 | Cost: 0.2945\n",
      "Epoch: 036/050 | Batch 050/891 | Cost: 0.0929\n",
      "Epoch: 036/050 | Batch 100/891 | Cost: 0.1919\n",
      "Epoch: 036/050 | Batch 150/891 | Cost: 0.1328\n",
      "Epoch: 036/050 | Batch 200/891 | Cost: 0.0948\n",
      "Epoch: 036/050 | Batch 250/891 | Cost: 0.0330\n",
      "Epoch: 036/050 | Batch 300/891 | Cost: 0.1418\n",
      "Epoch: 036/050 | Batch 350/891 | Cost: 0.3359\n",
      "Epoch: 036/050 | Batch 400/891 | Cost: 0.3079\n",
      "Epoch: 036/050 | Batch 450/891 | Cost: 0.1771\n",
      "Epoch: 036/050 | Batch 500/891 | Cost: 0.0698\n",
      "Epoch: 036/050 | Batch 550/891 | Cost: 0.1285\n",
      "Epoch: 036/050 | Batch 600/891 | Cost: 0.0174\n",
      "Epoch: 036/050 | Batch 650/891 | Cost: 0.1377\n",
      "Epoch: 036/050 | Batch 700/891 | Cost: 0.1203\n",
      "Epoch: 036/050 | Batch 750/891 | Cost: 0.0861\n",
      "Epoch: 036/050 | Batch 800/891 | Cost: 0.0767\n",
      "Epoch: 036/050 | Batch 850/891 | Cost: 0.1800\n",
      "training accuracy: 97.97%\n",
      "valid accuracy: 91.88%\n",
      "Time elapsed: 25.82 min\n",
      "Epoch: 037/050 | Batch 000/891 | Cost: 0.3566\n",
      "Epoch: 037/050 | Batch 050/891 | Cost: 0.1634\n",
      "Epoch: 037/050 | Batch 100/891 | Cost: 0.1186\n",
      "Epoch: 037/050 | Batch 150/891 | Cost: 0.1233\n",
      "Epoch: 037/050 | Batch 200/891 | Cost: 0.1115\n",
      "Epoch: 037/050 | Batch 250/891 | Cost: 0.1204\n",
      "Epoch: 037/050 | Batch 300/891 | Cost: 0.0447\n",
      "Epoch: 037/050 | Batch 350/891 | Cost: 0.1045\n",
      "Epoch: 037/050 | Batch 400/891 | Cost: 0.1046\n",
      "Epoch: 037/050 | Batch 450/891 | Cost: 0.0250\n",
      "Epoch: 037/050 | Batch 500/891 | Cost: 0.0988\n",
      "Epoch: 037/050 | Batch 550/891 | Cost: 0.1314\n",
      "Epoch: 037/050 | Batch 600/891 | Cost: 0.1060\n",
      "Epoch: 037/050 | Batch 650/891 | Cost: 0.1120\n",
      "Epoch: 037/050 | Batch 700/891 | Cost: 0.1844\n",
      "Epoch: 037/050 | Batch 750/891 | Cost: 0.0897\n",
      "Epoch: 037/050 | Batch 800/891 | Cost: 0.2487\n",
      "Epoch: 037/050 | Batch 850/891 | Cost: 0.1493\n",
      "training accuracy: 97.98%\n",
      "valid accuracy: 91.48%\n",
      "Time elapsed: 26.89 min\n",
      "Epoch: 038/050 | Batch 000/891 | Cost: 0.1361\n",
      "Epoch: 038/050 | Batch 050/891 | Cost: 0.1114\n",
      "Epoch: 038/050 | Batch 100/891 | Cost: 0.1495\n",
      "Epoch: 038/050 | Batch 150/891 | Cost: 0.0973\n",
      "Epoch: 038/050 | Batch 200/891 | Cost: 0.1874\n",
      "Epoch: 038/050 | Batch 250/891 | Cost: 0.1043\n",
      "Epoch: 038/050 | Batch 300/891 | Cost: 0.1514\n",
      "Epoch: 038/050 | Batch 350/891 | Cost: 0.2377\n",
      "Epoch: 038/050 | Batch 400/891 | Cost: 0.2675\n",
      "Epoch: 038/050 | Batch 450/891 | Cost: 0.0705\n",
      "Epoch: 038/050 | Batch 500/891 | Cost: 0.1921\n",
      "Epoch: 038/050 | Batch 550/891 | Cost: 0.0772\n",
      "Epoch: 038/050 | Batch 600/891 | Cost: 0.2542\n",
      "Epoch: 038/050 | Batch 650/891 | Cost: 0.0602\n",
      "Epoch: 038/050 | Batch 700/891 | Cost: 0.1468\n",
      "Epoch: 038/050 | Batch 750/891 | Cost: 0.0620\n",
      "Epoch: 038/050 | Batch 800/891 | Cost: 0.1213\n",
      "Epoch: 038/050 | Batch 850/891 | Cost: 0.1046\n",
      "training accuracy: 98.07%\n",
      "valid accuracy: 91.80%\n",
      "Time elapsed: 27.93 min\n",
      "Epoch: 039/050 | Batch 000/891 | Cost: 0.1133\n",
      "Epoch: 039/050 | Batch 050/891 | Cost: 0.1479\n",
      "Epoch: 039/050 | Batch 100/891 | Cost: 0.1279\n",
      "Epoch: 039/050 | Batch 150/891 | Cost: 0.1508\n",
      "Epoch: 039/050 | Batch 200/891 | Cost: 0.1695\n",
      "Epoch: 039/050 | Batch 250/891 | Cost: 0.1512\n",
      "Epoch: 039/050 | Batch 300/891 | Cost: 0.1059\n",
      "Epoch: 039/050 | Batch 350/891 | Cost: 0.0721\n",
      "Epoch: 039/050 | Batch 400/891 | Cost: 0.0856\n",
      "Epoch: 039/050 | Batch 450/891 | Cost: 0.1215\n",
      "Epoch: 039/050 | Batch 500/891 | Cost: 0.0628\n",
      "Epoch: 039/050 | Batch 550/891 | Cost: 0.1136\n",
      "Epoch: 039/050 | Batch 600/891 | Cost: 0.0866\n",
      "Epoch: 039/050 | Batch 650/891 | Cost: 0.0740\n",
      "Epoch: 039/050 | Batch 700/891 | Cost: 0.0922\n",
      "Epoch: 039/050 | Batch 750/891 | Cost: 0.0684\n",
      "Epoch: 039/050 | Batch 800/891 | Cost: 0.1036\n",
      "Epoch: 039/050 | Batch 850/891 | Cost: 0.3993\n",
      "training accuracy: 98.14%\n",
      "valid accuracy: 91.50%\n",
      "Time elapsed: 28.98 min\n",
      "Epoch: 040/050 | Batch 000/891 | Cost: 0.1712\n",
      "Epoch: 040/050 | Batch 050/891 | Cost: 0.1368\n",
      "Epoch: 040/050 | Batch 100/891 | Cost: 0.2130\n",
      "Epoch: 040/050 | Batch 150/891 | Cost: 0.2074\n",
      "Epoch: 040/050 | Batch 200/891 | Cost: 0.1886\n",
      "Epoch: 040/050 | Batch 250/891 | Cost: 0.0763\n",
      "Epoch: 040/050 | Batch 300/891 | Cost: 0.1250\n",
      "Epoch: 040/050 | Batch 350/891 | Cost: 0.0659\n",
      "Epoch: 040/050 | Batch 400/891 | Cost: 0.1597\n",
      "Epoch: 040/050 | Batch 450/891 | Cost: 0.0973\n",
      "Epoch: 040/050 | Batch 500/891 | Cost: 0.1974\n",
      "Epoch: 040/050 | Batch 550/891 | Cost: 0.0470\n",
      "Epoch: 040/050 | Batch 600/891 | Cost: 0.0981\n",
      "Epoch: 040/050 | Batch 650/891 | Cost: 0.2160\n",
      "Epoch: 040/050 | Batch 700/891 | Cost: 0.0991\n",
      "Epoch: 040/050 | Batch 750/891 | Cost: 0.1553\n",
      "Epoch: 040/050 | Batch 800/891 | Cost: 0.2289\n",
      "Epoch: 040/050 | Batch 850/891 | Cost: 0.1656\n",
      "training accuracy: 98.16%\n",
      "valid accuracy: 91.67%\n",
      "Time elapsed: 30.01 min\n",
      "Epoch: 041/050 | Batch 000/891 | Cost: 0.1532\n",
      "Epoch: 041/050 | Batch 050/891 | Cost: 0.1516\n",
      "Epoch: 041/050 | Batch 100/891 | Cost: 0.1026\n",
      "Epoch: 041/050 | Batch 150/891 | Cost: 0.2094\n",
      "Epoch: 041/050 | Batch 200/891 | Cost: 0.0773\n",
      "Epoch: 041/050 | Batch 250/891 | Cost: 0.0909\n",
      "Epoch: 041/050 | Batch 300/891 | Cost: 0.1079\n",
      "Epoch: 041/050 | Batch 350/891 | Cost: 0.2061\n",
      "Epoch: 041/050 | Batch 400/891 | Cost: 0.0633\n",
      "Epoch: 041/050 | Batch 450/891 | Cost: 0.1377\n",
      "Epoch: 041/050 | Batch 500/891 | Cost: 0.2176\n",
      "Epoch: 041/050 | Batch 550/891 | Cost: 0.1144\n",
      "Epoch: 041/050 | Batch 600/891 | Cost: 0.1907\n",
      "Epoch: 041/050 | Batch 650/891 | Cost: 0.1184\n",
      "Epoch: 041/050 | Batch 700/891 | Cost: 0.0938\n",
      "Epoch: 041/050 | Batch 750/891 | Cost: 0.0866\n",
      "Epoch: 041/050 | Batch 800/891 | Cost: 0.1442\n",
      "Epoch: 041/050 | Batch 850/891 | Cost: 0.0893\n",
      "training accuracy: 98.25%\n",
      "valid accuracy: 91.70%\n",
      "Time elapsed: 31.05 min\n",
      "Epoch: 042/050 | Batch 000/891 | Cost: 0.1878\n",
      "Epoch: 042/050 | Batch 050/891 | Cost: 0.1001\n",
      "Epoch: 042/050 | Batch 100/891 | Cost: 0.0742\n",
      "Epoch: 042/050 | Batch 150/891 | Cost: 0.1685\n",
      "Epoch: 042/050 | Batch 200/891 | Cost: 0.0812\n",
      "Epoch: 042/050 | Batch 250/891 | Cost: 0.1662\n",
      "Epoch: 042/050 | Batch 300/891 | Cost: 0.0969\n",
      "Epoch: 042/050 | Batch 350/891 | Cost: 0.1765\n",
      "Epoch: 042/050 | Batch 400/891 | Cost: 0.0659\n",
      "Epoch: 042/050 | Batch 450/891 | Cost: 0.1227\n",
      "Epoch: 042/050 | Batch 500/891 | Cost: 0.0946\n",
      "Epoch: 042/050 | Batch 550/891 | Cost: 0.1164\n",
      "Epoch: 042/050 | Batch 600/891 | Cost: 0.1121\n",
      "Epoch: 042/050 | Batch 650/891 | Cost: 0.1068\n",
      "Epoch: 042/050 | Batch 700/891 | Cost: 0.0964\n",
      "Epoch: 042/050 | Batch 750/891 | Cost: 0.1052\n",
      "Epoch: 042/050 | Batch 800/891 | Cost: 0.0914\n",
      "Epoch: 042/050 | Batch 850/891 | Cost: 0.1908\n",
      "training accuracy: 98.24%\n",
      "valid accuracy: 91.52%\n",
      "Time elapsed: 32.08 min\n",
      "Epoch: 043/050 | Batch 000/891 | Cost: 0.1148\n",
      "Epoch: 043/050 | Batch 050/891 | Cost: 0.0874\n",
      "Epoch: 043/050 | Batch 100/891 | Cost: 0.1539\n",
      "Epoch: 043/050 | Batch 150/891 | Cost: 0.1270\n",
      "Epoch: 043/050 | Batch 200/891 | Cost: 0.0444\n",
      "Epoch: 043/050 | Batch 250/891 | Cost: 0.0705\n",
      "Epoch: 043/050 | Batch 300/891 | Cost: 0.1335\n",
      "Epoch: 043/050 | Batch 350/891 | Cost: 0.2058\n",
      "Epoch: 043/050 | Batch 400/891 | Cost: 0.1839\n",
      "Epoch: 043/050 | Batch 450/891 | Cost: 0.1798\n",
      "Epoch: 043/050 | Batch 500/891 | Cost: 0.1855\n",
      "Epoch: 043/050 | Batch 550/891 | Cost: 0.1608\n",
      "Epoch: 043/050 | Batch 600/891 | Cost: 0.1785\n",
      "Epoch: 043/050 | Batch 650/891 | Cost: 0.1823\n",
      "Epoch: 043/050 | Batch 700/891 | Cost: 0.1660\n",
      "Epoch: 043/050 | Batch 750/891 | Cost: 0.2193\n",
      "Epoch: 043/050 | Batch 800/891 | Cost: 0.1133\n",
      "Epoch: 043/050 | Batch 850/891 | Cost: 0.0708\n",
      "training accuracy: 98.25%\n",
      "valid accuracy: 91.60%\n",
      "Time elapsed: 33.12 min\n",
      "Epoch: 044/050 | Batch 000/891 | Cost: 0.1061\n",
      "Epoch: 044/050 | Batch 050/891 | Cost: 0.1410\n",
      "Epoch: 044/050 | Batch 100/891 | Cost: 0.0963\n",
      "Epoch: 044/050 | Batch 150/891 | Cost: 0.0455\n",
      "Epoch: 044/050 | Batch 200/891 | Cost: 0.1148\n",
      "Epoch: 044/050 | Batch 250/891 | Cost: 0.0956\n",
      "Epoch: 044/050 | Batch 300/891 | Cost: 0.1357\n",
      "Epoch: 044/050 | Batch 350/891 | Cost: 0.0914\n",
      "Epoch: 044/050 | Batch 400/891 | Cost: 0.1779\n",
      "Epoch: 044/050 | Batch 450/891 | Cost: 0.0951\n",
      "Epoch: 044/050 | Batch 500/891 | Cost: 0.0805\n",
      "Epoch: 044/050 | Batch 550/891 | Cost: 0.0946\n",
      "Epoch: 044/050 | Batch 600/891 | Cost: 0.2519\n",
      "Epoch: 044/050 | Batch 650/891 | Cost: 0.0587\n",
      "Epoch: 044/050 | Batch 700/891 | Cost: 0.1026\n",
      "Epoch: 044/050 | Batch 750/891 | Cost: 0.0970\n",
      "Epoch: 044/050 | Batch 800/891 | Cost: 0.1420\n",
      "Epoch: 044/050 | Batch 850/891 | Cost: 0.0799\n",
      "training accuracy: 98.27%\n",
      "valid accuracy: 91.50%\n",
      "Time elapsed: 34.09 min\n",
      "Epoch: 045/050 | Batch 000/891 | Cost: 0.1535\n",
      "Epoch: 045/050 | Batch 050/891 | Cost: 0.1314\n",
      "Epoch: 045/050 | Batch 100/891 | Cost: 0.0673\n",
      "Epoch: 045/050 | Batch 150/891 | Cost: 0.1049\n",
      "Epoch: 045/050 | Batch 200/891 | Cost: 0.0908\n",
      "Epoch: 045/050 | Batch 250/891 | Cost: 0.2232\n",
      "Epoch: 045/050 | Batch 300/891 | Cost: 0.0698\n",
      "Epoch: 045/050 | Batch 350/891 | Cost: 0.0505\n",
      "Epoch: 045/050 | Batch 400/891 | Cost: 0.0682\n",
      "Epoch: 045/050 | Batch 450/891 | Cost: 0.1018\n",
      "Epoch: 045/050 | Batch 500/891 | Cost: 0.0461\n",
      "Epoch: 045/050 | Batch 550/891 | Cost: 0.1451\n",
      "Epoch: 045/050 | Batch 600/891 | Cost: 0.0264\n",
      "Epoch: 045/050 | Batch 650/891 | Cost: 0.0608\n",
      "Epoch: 045/050 | Batch 700/891 | Cost: 0.1043\n",
      "Epoch: 045/050 | Batch 750/891 | Cost: 0.0882\n",
      "Epoch: 045/050 | Batch 800/891 | Cost: 0.1163\n",
      "Epoch: 045/050 | Batch 850/891 | Cost: 0.2396\n",
      "training accuracy: 98.29%\n",
      "valid accuracy: 91.40%\n",
      "Time elapsed: 35.03 min\n",
      "Epoch: 046/050 | Batch 000/891 | Cost: 0.0788\n",
      "Epoch: 046/050 | Batch 050/891 | Cost: 0.0304\n",
      "Epoch: 046/050 | Batch 100/891 | Cost: 0.0826\n",
      "Epoch: 046/050 | Batch 150/891 | Cost: 0.1860\n",
      "Epoch: 046/050 | Batch 200/891 | Cost: 0.1872\n",
      "Epoch: 046/050 | Batch 250/891 | Cost: 0.0610\n",
      "Epoch: 046/050 | Batch 300/891 | Cost: 0.1037\n",
      "Epoch: 046/050 | Batch 350/891 | Cost: 0.1565\n",
      "Epoch: 046/050 | Batch 400/891 | Cost: 0.1976\n",
      "Epoch: 046/050 | Batch 450/891 | Cost: 0.1081\n",
      "Epoch: 046/050 | Batch 500/891 | Cost: 0.1374\n",
      "Epoch: 046/050 | Batch 550/891 | Cost: 0.0744\n",
      "Epoch: 046/050 | Batch 600/891 | Cost: 0.0795\n",
      "Epoch: 046/050 | Batch 650/891 | Cost: 0.1045\n",
      "Epoch: 046/050 | Batch 700/891 | Cost: 0.2454\n",
      "Epoch: 046/050 | Batch 750/891 | Cost: 0.1897\n",
      "Epoch: 046/050 | Batch 800/891 | Cost: 0.0899\n",
      "Epoch: 046/050 | Batch 850/891 | Cost: 0.1644\n",
      "training accuracy: 98.52%\n",
      "valid accuracy: 91.80%\n",
      "Time elapsed: 35.97 min\n",
      "Epoch: 047/050 | Batch 000/891 | Cost: 0.0844\n",
      "Epoch: 047/050 | Batch 050/891 | Cost: 0.1276\n",
      "Epoch: 047/050 | Batch 100/891 | Cost: 0.1050\n",
      "Epoch: 047/050 | Batch 150/891 | Cost: 0.0994\n",
      "Epoch: 047/050 | Batch 200/891 | Cost: 0.0310\n",
      "Epoch: 047/050 | Batch 250/891 | Cost: 0.1233\n",
      "Epoch: 047/050 | Batch 300/891 | Cost: 0.1956\n",
      "Epoch: 047/050 | Batch 350/891 | Cost: 0.1355\n",
      "Epoch: 047/050 | Batch 400/891 | Cost: 0.0901\n",
      "Epoch: 047/050 | Batch 450/891 | Cost: 0.1141\n",
      "Epoch: 047/050 | Batch 500/891 | Cost: 0.1127\n",
      "Epoch: 047/050 | Batch 550/891 | Cost: 0.1333\n",
      "Epoch: 047/050 | Batch 600/891 | Cost: 0.0607\n",
      "Epoch: 047/050 | Batch 650/891 | Cost: 0.0458\n",
      "Epoch: 047/050 | Batch 700/891 | Cost: 0.0623\n",
      "Epoch: 047/050 | Batch 750/891 | Cost: 0.1557\n",
      "Epoch: 047/050 | Batch 800/891 | Cost: 0.0998\n",
      "Epoch: 047/050 | Batch 850/891 | Cost: 0.1906\n",
      "training accuracy: 98.39%\n",
      "valid accuracy: 91.62%\n",
      "Time elapsed: 36.90 min\n",
      "Epoch: 048/050 | Batch 000/891 | Cost: 0.0498\n",
      "Epoch: 048/050 | Batch 050/891 | Cost: 0.1280\n",
      "Epoch: 048/050 | Batch 100/891 | Cost: 0.3360\n",
      "Epoch: 048/050 | Batch 150/891 | Cost: 0.1495\n",
      "Epoch: 048/050 | Batch 200/891 | Cost: 0.1255\n",
      "Epoch: 048/050 | Batch 250/891 | Cost: 0.0538\n",
      "Epoch: 048/050 | Batch 300/891 | Cost: 0.1525\n",
      "Epoch: 048/050 | Batch 350/891 | Cost: 0.0628\n",
      "Epoch: 048/050 | Batch 400/891 | Cost: 0.0923\n",
      "Epoch: 048/050 | Batch 450/891 | Cost: 0.2230\n",
      "Epoch: 048/050 | Batch 500/891 | Cost: 0.3083\n",
      "Epoch: 048/050 | Batch 550/891 | Cost: 0.0439\n",
      "Epoch: 048/050 | Batch 600/891 | Cost: 0.0468\n",
      "Epoch: 048/050 | Batch 650/891 | Cost: 0.0583\n",
      "Epoch: 048/050 | Batch 700/891 | Cost: 0.1199\n",
      "Epoch: 048/050 | Batch 750/891 | Cost: 0.0736\n",
      "Epoch: 048/050 | Batch 800/891 | Cost: 0.1704\n",
      "Epoch: 048/050 | Batch 850/891 | Cost: 0.1210\n",
      "training accuracy: 98.62%\n",
      "valid accuracy: 91.67%\n",
      "Time elapsed: 37.94 min\n",
      "Epoch: 049/050 | Batch 000/891 | Cost: 0.0950\n",
      "Epoch: 049/050 | Batch 050/891 | Cost: 0.0561\n",
      "Epoch: 049/050 | Batch 100/891 | Cost: 0.0741\n",
      "Epoch: 049/050 | Batch 150/891 | Cost: 0.1510\n",
      "Epoch: 049/050 | Batch 200/891 | Cost: 0.0725\n",
      "Epoch: 049/050 | Batch 250/891 | Cost: 0.1095\n",
      "Epoch: 049/050 | Batch 300/891 | Cost: 0.0607\n",
      "Epoch: 049/050 | Batch 350/891 | Cost: 0.1911\n",
      "Epoch: 049/050 | Batch 400/891 | Cost: 0.0869\n",
      "Epoch: 049/050 | Batch 450/891 | Cost: 0.0695\n",
      "Epoch: 049/050 | Batch 500/891 | Cost: 0.1631\n",
      "Epoch: 049/050 | Batch 550/891 | Cost: 0.2730\n",
      "Epoch: 049/050 | Batch 600/891 | Cost: 0.0997\n",
      "Epoch: 049/050 | Batch 650/891 | Cost: 0.0588\n",
      "Epoch: 049/050 | Batch 700/891 | Cost: 0.0969\n",
      "Epoch: 049/050 | Batch 750/891 | Cost: 0.1929\n",
      "Epoch: 049/050 | Batch 800/891 | Cost: 0.0639\n",
      "Epoch: 049/050 | Batch 850/891 | Cost: 0.1441\n",
      "training accuracy: 98.67%\n",
      "valid accuracy: 91.80%\n",
      "Time elapsed: 38.98 min\n",
      "Epoch: 050/050 | Batch 000/891 | Cost: 0.0646\n",
      "Epoch: 050/050 | Batch 050/891 | Cost: 0.1085\n",
      "Epoch: 050/050 | Batch 100/891 | Cost: 0.1356\n",
      "Epoch: 050/050 | Batch 150/891 | Cost: 0.0649\n",
      "Epoch: 050/050 | Batch 200/891 | Cost: 0.1520\n",
      "Epoch: 050/050 | Batch 250/891 | Cost: 0.0987\n",
      "Epoch: 050/050 | Batch 300/891 | Cost: 0.1930\n",
      "Epoch: 050/050 | Batch 350/891 | Cost: 0.2051\n",
      "Epoch: 050/050 | Batch 400/891 | Cost: 0.1187\n",
      "Epoch: 050/050 | Batch 450/891 | Cost: 0.0401\n",
      "Epoch: 050/050 | Batch 500/891 | Cost: 0.0716\n",
      "Epoch: 050/050 | Batch 550/891 | Cost: 0.1372\n",
      "Epoch: 050/050 | Batch 600/891 | Cost: 0.1621\n",
      "Epoch: 050/050 | Batch 650/891 | Cost: 0.1026\n",
      "Epoch: 050/050 | Batch 700/891 | Cost: 0.1087\n",
      "Epoch: 050/050 | Batch 750/891 | Cost: 0.1647\n",
      "Epoch: 050/050 | Batch 800/891 | Cost: 0.1104\n",
      "Epoch: 050/050 | Batch 850/891 | Cost: 0.0536\n",
      "training accuracy: 98.72%\n",
      "valid accuracy: 91.85%\n",
      "Time elapsed: 40.01 min\n",
      "Total Training Time: 40.01 min\n",
      "Test accuracy: 91.26%\n"
     ]
    }
   ],
   "source": [
    "start_time = time.time()\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    model.train()\n",
    "    for batch_idx, batch_data in enumerate(train_loader):\n",
    "        \n",
    "        text, text_lengths = batch_data.content\n",
    "        \n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits = model(text, text_lengths).squeeze(1)\n",
    "        cost = F.cross_entropy(logits, batch_data.classlabel.long())\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 50:\n",
    "            print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n",
    "                  f'Batch {batch_idx:03d}/{len(train_loader):03d} | '\n",
    "                  f'Cost: {cost:.4f}')\n",
    "\n",
    "    with torch.set_grad_enabled(False):\n",
    "        print(f'training accuracy: '\n",
    "              f'{compute_accuracy(model, train_loader, DEVICE):.2f}%'\n",
    "              f'\\nvalid accuracy: '\n",
    "              f'{compute_accuracy(model, valid_loader, DEVICE):.2f}%')\n",
    "        \n",
    "    print(f'Time elapsed: {(time.time() - start_time)/60:.2f} min')\n",
    "    \n",
    "print(f'Total Training Time: {(time.time() - start_time)/60:.2f} min')\n",
    "print(f'Test accuracy: {compute_accuracy(model, test_loader, DEVICE):.2f}%')"
   ]
  },
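  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training and validation accuracies printed above are, presumably, the percentages of examples whose predicted class (the index of the highest output score) matches the true label. The `compute_accuracy` helper that produces them is not shown in this section. The next cell is therefore only a minimal, framework-free sketch of that calculation, under two assumptions: it takes plain Python lists of per-example class scores instead of a data loader, and the function name `accuracy_percent` is hypothetical (it is not part of this notebook, PyTorch, or torchtext). With the toy values used there, three of the four predictions match their labels, so it prints an accuracy of 75.00%."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch (hypothetical helper, not the notebook's compute_accuracy):\n",
    "# accuracy as the percentage of argmax predictions that match the labels\n",
    "def accuracy_percent(logits, labels):\n",
    "    correct = 0\n",
    "    for scores, label in zip(logits, labels):\n",
    "        # Predicted class = index of the largest score for this example\n",
    "        pred = max(range(len(scores)), key=lambda i: scores[i])\n",
    "        correct += int(pred == label)\n",
    "    return 100.0 * correct / len(labels)\n",
    "\n",
    "\n",
    "# Toy batch: 4 examples, 4 classes (World, Sports, Business, Sci/Tech)\n",
    "toy_logits = [\n",
    "    [2.0, 0.1, 0.3, -1.0],  # predicts class 0\n",
    "    [0.2, 1.5, 0.1, 0.0],   # predicts class 1\n",
    "    [0.0, 0.1, 0.2, 3.0],   # predicts class 3\n",
    "    [1.0, 0.5, 0.9, 0.2],   # predicts class 0\n",
    "]\n",
    "toy_labels = [0, 1, 2, 0]\n",
    "print(f'toy accuracy: {accuracy_percent(toy_logits, toy_labels):.2f}%')"
   ]
  },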
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, the trained model is evaluated on a few news snippets taken from recent articles. These texts are not part of the training, validation, or test sets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "jt55pscgFdKZ"
   },
   "outputs": [],
   "source": [
    "import spacy\n",
    "\n",
    "# 'en' is a spaCy 2.x shortcut link; spaCy 3.x requires a full\n",
    "# model name such as 'en_core_web_sm'\n",
    "nlp = spacy.load('en')\n",
    "\n",
    "\n",
    "map_dictionary = {\n",
    "    0: \"World\",\n",
    "    1: \"Sports\",\n",
    "    2: \"Business\",\n",
    "    3: \"Sci/Tech\",\n",
    "}\n",
    "\n",
    "\n",
    "def predict_class(model, sentence, min_len=4):\n",
    "    # Somewhat based on\n",
    "    # https://github.com/bentrevett/pytorch-sentiment-analysis/\n",
    "    # blob/master/5%20-%20Multi-class%20Sentiment%20Analysis.ipynb\n",
    "    model.eval()\n",
    "    # Tokenize, then pad very short inputs to at least min_len tokens\n",
    "    tokenized = [tok.text for tok in nlp.tokenizer(sentence)]\n",
    "    if len(tokenized) < min_len:\n",
    "        tokenized += ['<pad>'] * (min_len - len(tokenized))\n",
    "    # Map each token to its vocabulary index\n",
    "    indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n",
    "    length = [len(indexed)]\n",
    "    # Shape [seq_len, 1]: a batch containing a single example\n",
    "    tensor = torch.LongTensor(indexed).to(DEVICE)\n",
    "    tensor = tensor.unsqueeze(1)\n",
    "    length_tensor = torch.LongTensor(length)\n",
    "    # Inference only: no gradients needed\n",
    "    with torch.no_grad():\n",
    "        preds = model(tensor, length_tensor)\n",
    "        preds = torch.softmax(preds, dim=1)\n",
    "    \n",
    "    proba, class_label = preds.max(dim=1)\n",
    "    return proba.item(), class_label.item()"
   ]
  },
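  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`predict_class` tokenizes the input with spaCy, pads inputs shorter than `min_len` tokens with `'<pad>'`, looks up each token's integer index in `TEXT.vocab.stoi`, and reshapes the indices into a `[sequence_length, 1]` tensor, that is, a batch holding a single example. The next cell is a minimal plain-Python sketch of only the padding and index-lookup steps, under these assumptions: whitespace splitting stands in for the spaCy tokenizer, `toy_stoi` is a made-up four-entry vocabulary rather than the notebook's `TEXT.vocab`, unknown words fall back to the `'<unk>'` index, and the helper name `numericalize` is hypothetical. For the two-word input `'stocks rally'`, it pads to four tokens and returns the indices `[2, 3, 1, 1]` with length 4."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch (hypothetical helper): pad and numericalize one sentence\n",
    "def numericalize(sentence, stoi, unk_idx, min_len=4, pad_token='<pad>'):\n",
    "    # Whitespace splitting stands in for the spaCy tokenizer\n",
    "    tokens = sentence.split()\n",
    "    # Pad very short inputs so the sequence has at least min_len tokens\n",
    "    if len(tokens) < min_len:\n",
    "        tokens += [pad_token] * (min_len - len(tokens))\n",
    "    # Tokens missing from the vocabulary map to the unknown-word index\n",
    "    indexed = [stoi.get(tok, unk_idx) for tok in tokens]\n",
    "    return indexed, len(indexed)\n",
    "\n",
    "\n",
    "# Made-up toy vocabulary (not the notebook's TEXT.vocab)\n",
    "toy_stoi = {'<unk>': 0, '<pad>': 1, 'stocks': 2, 'rally': 3}\n",
    "\n",
    "indexed, length = numericalize('stocks rally', toy_stoi, unk_idx=toy_stoi['<unk>'])\n",
    "# Column layout [seq_len, 1], mirroring tensor.unsqueeze(1) in predict_class\n",
    "column = [[idx] for idx in indexed]\n",
    "print(indexed, length, column)"
   ]
  },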
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Class Label: 2 -> Business\n",
      "Probability: 0.878576934337616\n"
     ]
    }
   ],
   "source": [
    "text = \"\"\"\n",
    "The windfall follows a tender offer by Z Holdings, which is controlled by SoftBank’s domestic wireless unit, \n",
    "for half of Zozo’s shares this month.\n",
    "\"\"\"\n",
    "\n",
    "proba, pred_label = predict_class(model, text)\n",
    "\n",
    "print(f'Class Label: {pred_label} -> {map_dictionary[pred_label]}')\n",
    "print(f'Probability: {proba}')"
   ]
  },
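  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printed probability is the largest entry of the softmax over the model's four output scores, and the class label is the index of that entry (`preds.max(dim=1)` in `predict_class`). The next cell sketches this final step in plain Python with made-up scores, not actual model outputs. The names `softmax`, `label_names`, and `example_scores` are hypothetical. Subtracting the maximum score before exponentiating prevents overflow without changing the resulting probabilities. For the scores used there, the sketch predicts class 2 (Business) with a probability of about 0.80."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def softmax(scores):\n",
    "    # Subtract the max score for numerical stability (result is unchanged)\n",
    "    m = max(scores)\n",
    "    exps = [math.exp(s - m) for s in scores]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "\n",
    "label_names = {0: 'World', 1: 'Sports', 2: 'Business', 3: 'Sci/Tech'}\n",
    "\n",
    "# Made-up scores for one example (not actual model output)\n",
    "example_scores = [0.5, -1.0, 2.5, 0.0]\n",
    "probas = softmax(example_scores)\n",
    "# Predicted class = index of the largest probability\n",
    "class_label = max(range(len(probas)), key=lambda i: probas[i])\n",
    "print(f'Class Label: {class_label} -> {label_names[class_label]}')\n",
    "print(f'Probability: {probas[class_label]:.4f}')"
   ]
  },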
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Class Label: 0 -> World\n",
      "Probability: 0.9969592094421387\n"
     ]
    }
   ],
   "source": [
    "text = \"\"\"\n",
    "EU data regulator issues first-ever sanction of an EU institution, \n",
    "against the European parliament over its use of US-based NationBuilder to process voter data \n",
    "\"\"\"\n",
    "\n",
    "proba, pred_label = predict_class(model, text)\n",
    "\n",
    "print(f'Class Label: {pred_label} -> {map_dictionary[pred_label]}')\n",
    "print(f'Probability: {proba}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Class Label: 2 -> Business\n",
      "Probability: 0.9953342080116272\n"
     ]
    }
   ],
   "source": [
    "text = \"\"\"\n",
    "LG announces CEO Jo Seong-jin will be replaced by Brian Kwon Dec. 1, amid 2020 \n",
    "leadership shakeup and LG smartphone division's 18th straight quarterly loss\n",
    "\"\"\"\n",
    "\n",
    "proba, pred_label = predict_class(model, text)\n",
    "\n",
    "print(f'Class Label: {pred_label} -> {map_dictionary[pred_label]}')\n",
    "print(f'Probability: {proba}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "7lRusB3dF80X"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy     1.17.2\n",
      "pandas    0.24.2\n",
      "torch     1.3.0\n",
      "torchtext 0.4.0\n",
      "spacy     2.2.3\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "rnn_lstm_packed_imdb.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
